diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-15 10:39:30 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-15 10:39:30 +0900 |
| commit | 8f3476bb0eb1a502772858a46b66903ca277456e (patch) | |
| tree | 2ee589b49605f2c7042428c7ce59b370407a3d6e /become_yukarin/loss.py | |
| parent | 4003ae8e457905070b789b75c5972ca93cc5756b (diff) | |
mask
Diffstat (limited to 'become_yukarin/loss.py')
| -rw-r--r-- | become_yukarin/loss.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index c088691..3d89908 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -14,11 +14,16 @@ class Loss(chainer.link.Chain): with self.init_scope(): self.predictor = predictor - def __call__(self, input, target): + def __call__(self, input, target, mask): + input = chainer.as_variable(input) + target = chainer.as_variable(target) + mask = chainer.as_variable(mask) + h = input y = self.predictor(h) - loss = chainer.functions.mean_absolute_error(y, target) + loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) + loss = loss / chainer.functions.sum(mask) reporter.report({'loss': loss}, self) return loss * self.config.l1 |
