summaryrefslogtreecommitdiff
path: root/become_yukarin/loss.py
blob: b2b03fc97d461b1943f23cc8c5d2b7396f044613 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import chainer
from chainer import reporter

from .config import LossConfig
from .model import Aligner
from .model import Predictor


class Loss(chainer.link.Chain):
    """Masked L1 training loss wrapping a ``Predictor`` (and optional ``Aligner``).

    The forward pass optionally runs ``aligner`` on the input, feeds the result to
    ``predictor``, and compares the prediction to ``target`` with an element-wise
    absolute error. The error is weighted by ``mask`` and normalized by the total
    mask weight.
    """

    def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None):
        """Build the loss chain.

        Args:
            config: Loss configuration; ``config.l1`` scales the returned loss.
            predictor: Link that maps (optionally aligned) input to a prediction.
            aligner: Optional link applied to the input before ``predictor``.
                When ``None``, the input is passed to ``predictor`` unchanged.
        """
        super().__init__()
        self.config = config

        # Registering inside init_scope makes the sub-links' parameters part of
        # this chain, so an optimizer attached to ``Loss`` updates them.
        with self.init_scope():
            self.predictor = predictor
            self.aligner = aligner

    def __call__(self, input, target, mask):
        """Compute the masked, normalized L1 loss.

        Args:
            input: Model input (array or Variable).
            target: Expected output; must be shape-compatible with the
                predictor's output for ``absolute_error``.
            mask: Per-element weights for the error. It is multiplied
                element-wise with the error, so it presumably has the same shape
                and dtype as the prediction (typically 0/1 to ignore padded or
                unwanted elements) — confirm against callers.

        Returns:
            Scalar Variable: ``sum(|y - target| * mask) / max(sum(mask), 1)``
            multiplied by ``config.l1``. The unscaled value is reported to the
            Chainer reporter under the key ``'loss'``.
        """
        input = chainer.as_variable(input)
        target = chainer.as_variable(target)
        mask = chainer.as_variable(mask)

        h = input
        if self.aligner is not None:
            h = self.aligner(h)
        y = self.predictor(h)

        masked_error = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
        mask_total = chainer.functions.sum(mask)

        # An all-zero mask would give 0 / 0 = NaN, which would be reported and
        # backpropagated into every parameter. Clamp the denominator to >= 1.
        # With a 0/1 mask the sum is either 0 or >= 1, so this only changes the
        # fully-masked case (the loss becomes 0 instead of NaN).
        floor = self.xp.ones_like(mask_total.array)
        loss = masked_error / chainer.functions.maximum(mask_total, floor)
        reporter.report({'loss': loss}, self)

        return loss * self.config.l1