diff options
Diffstat (limited to 'become_yukarin/loss.py')
| -rw-r--r-- | become_yukarin/loss.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index c088691..3d89908 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -14,11 +14,16 @@ class Loss(chainer.link.Chain): with self.init_scope(): self.predictor = predictor - def __call__(self, input, target): + def __call__(self, input, target, mask): + input = chainer.as_variable(input) + target = chainer.as_variable(target) + mask = chainer.as_variable(mask) + h = input y = self.predictor(h) - loss = chainer.functions.mean_absolute_error(y, target) + loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) + loss = loss / chainer.functions.sum(mask) reporter.report({'loss': loss}, self) return loss * self.config.l1 |
