diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-22 23:50:31 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-22 23:50:31 +0900 |
| commit | 9f87a74de09e38f9d8f3e7ebb5fd26fac44a3b0e (patch) | |
| tree | ae662b5319256e3864877cacbd21c527f33448f0 /become_yukarin/loss.py | |
| parent | d6af2a851644afe253b97461b35138011a479a95 (diff) | |
can remove aligner
Diffstat (limited to 'become_yukarin/loss.py')
| -rw-r--r-- | become_yukarin/loss.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index c59747a..b2b03fc 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -7,7 +7,7 @@ from .model import Predictor class Loss(chainer.link.Chain): - def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner): + def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None): super().__init__() self.config = config @@ -21,7 +21,8 @@ class Loss(chainer.link.Chain): mask = chainer.as_variable(mask) h = input - h = self.aligner(h) + if self.aligner is not None: + h = self.aligner(h) y = self.predictor(h) loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) |
