summaryrefslogtreecommitdiff
path: root/become_yukarin/loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/loss.py')
-rw-r--r--become_yukarin/loss.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py
index c088691..3d89908 100644
--- a/become_yukarin/loss.py
+++ b/become_yukarin/loss.py
@@ -14,11 +14,16 @@ class Loss(chainer.link.Chain):
with self.init_scope():
self.predictor = predictor
- def __call__(self, input, target):
+ def __call__(self, input, target, mask):
+ input = chainer.as_variable(input)
+ target = chainer.as_variable(target)
+ mask = chainer.as_variable(mask)
+
h = input
y = self.predictor(h)
- loss = chainer.functions.mean_absolute_error(y, target)
+ loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
+ loss = loss / chainer.functions.sum(mask)
reporter.report({'loss': loss}, self)
return loss * self.config.l1