summaryrefslogtreecommitdiff
path: root/become_yukarin/loss.py
blob: c0886912d6a95167822906136f2103620ea7911e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from .config import LossConfig
from .model import Model

import chainer

from chainer import reporter


class Loss(chainer.link.Chain):
    """Loss link wrapping a predictor network.

    Calling an instance runs the predictor on the input, measures the mean
    absolute error between the prediction and the target, reports that error
    under the key ``'loss'``, and returns it multiplied by ``config.l1``.
    """

    def __init__(self, config: LossConfig, predictor: Model):
        """Keep the loss configuration and attach the predictor.

        Args:
            config: Loss settings. Only ``config.l1`` is read by this class,
                as the multiplier applied to the returned loss.
            predictor: Network invoked on the input in ``__call__``.
        """
        super().__init__()

        # Assigned inside init_scope so the chain registers the predictor
        # as one of its child links.
        with self.init_scope():
            self.predictor = predictor

        self.config = config

    def __call__(self, input, target):
        """Return the L1 loss between the prediction and ``target``.

        Args:
            input: Value passed unchanged to the predictor.
            target: Reference value the prediction is compared against.

        Returns:
            The mean absolute error multiplied by ``config.l1``.
        """
        prediction = self.predictor(input)
        l1_error = chainer.functions.mean_absolute_error(prediction, target)

        # The reported value is the unweighted error; the ``l1`` weight is
        # applied only to the value handed back to the caller.
        reporter.report({'loss': l1_error}, self)

        return l1_error * self.config.l1