summaryrefslogtreecommitdiff
path: root/become_yukarin/loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/loss.py')
-rw-r--r--become_yukarin/loss.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py
index c59747a..b2b03fc 100644
--- a/become_yukarin/loss.py
+++ b/become_yukarin/loss.py
@@ -7,7 +7,7 @@ from .model import Predictor
class Loss(chainer.link.Chain):
- def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner):
+ def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None):
super().__init__()
self.config = config
@@ -21,7 +21,8 @@ class Loss(chainer.link.Chain):
mask = chainer.as_variable(mask)
h = input
- h = self.aligner(h)
+ if self.aligner is not None:
+ h = self.aligner(h)
y = self.predictor(h)
loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)