diff options
Diffstat (limited to 'become_yukarin/loss.py')
| -rw-r--r-- | become_yukarin/loss.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index c59747a..b2b03fc 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -7,7 +7,7 @@ from .model import Predictor class Loss(chainer.link.Chain): - def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner): + def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None): super().__init__() self.config = config @@ -21,7 +21,8 @@ class Loss(chainer.link.Chain): mask = chainer.as_variable(mask) h = input - h = self.aligner(h) + if self.aligner is not None: + h = self.aligner(h) y = self.predictor(h) loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) |
