path: root/become_yukarin/updater.py
from typing import Optional

import chainer
import numpy
from chainer import reporter

from .config import LossConfig
from .config import ModelConfig
from .model import Aligner
from .model import Discriminator
from .model import Predictor


class Updater(chainer.training.StandardUpdater):
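    """StandardUpdater that trains the predictor with a masked L1 loss,
    plus optional aligner and least-squares GAN discriminator losses.
    Each model is updated by the optimizer registered under its name."""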
    def __init__(
            self,
            loss_config: LossConfig,
            model_config: ModelConfig,
            predictor: Predictor,
            aligner: Optional[Aligner] = None,
            discriminator: Optional[Discriminator] = None,
            *args,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss_config = loss_config
        self.model_config = model_config
        self.predictor = predictor
        self.aligner = aligner
        self.discriminator = discriminator

    def forward(self, input, target, mask):
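        """Compute the losses for one batch and return them as a dict keyed
        by model name ('predictor', and optionally 'aligner'/'discriminator')."""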
        xp = self.predictor.xp

        input = chainer.as_variable(input)
        target = chainer.as_variable(target)
        mask = chainer.as_variable(mask)

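        # Optionally warp the input features through the aligner first.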
        if self.aligner is not None:
            input = self.aligner(input)
        y = self.predictor(input)

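        # Masked L1 loss, averaged over the unmasked elements only.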
        loss_l1 = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
        loss_l1 = loss_l1 / chainer.functions.sum(mask)
        reporter.report({'l1': loss_l1}, self.predictor)

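        # Conditional discriminator: score the masked output (fake) and the
        # masked target (real), each concatenated with the input as condition.
        # Least-squares GAN losses push real scores to 1 and fake scores to 0;
        # the generator term instead pushes fake scores toward 1.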
        if self.discriminator is not None:
            pair_fake = chainer.functions.concat([y * mask, input])
            pair_true = chainer.functions.concat([target * mask, input])
            d_fake = self.discriminator(pair_fake)
            d_true = self.discriminator(pair_true)
            loss_dis_f = chainer.functions.mean_squared_error(d_fake, xp.zeros_like(d_fake.data, numpy.float32))
            loss_dis_t = chainer.functions.mean_squared_error(d_true, xp.ones_like(d_true.data, numpy.float32))
            loss_gen_f = chainer.functions.mean_squared_error(d_fake, xp.ones_like(d_fake.data, numpy.float32))
            reporter.report({'fake': loss_dis_f}, self.discriminator)
            reporter.report({'true': loss_dis_t}, self.discriminator)

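        # Weight each term and collect one loss per model.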
        loss = {'predictor': loss_l1 * self.loss_config.l1}

        if self.aligner is not None:
            loss['aligner'] = loss_l1 * self.loss_config.l1
            reporter.report({'loss': loss['aligner']}, self.aligner)

        if self.discriminator is not None:
            loss['discriminator'] = \
                loss_dis_f * self.loss_config.discriminator_fake + \
                loss_dis_t * self.loss_config.discriminator_true
            reporter.report({'loss': loss['discriminator']}, self.discriminator)
            loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake

        reporter.report({'loss': loss['predictor']}, self.predictor)
        return loss

    def update_core(self):
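        # Compute the per-model losses for the next batch.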
        batch = self.get_iterator('main').next()
        loss = self.forward(**self.converter(batch, self.device))

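        # Each optimizer pulls its own loss from the dict by its registered
        # name, so optimizer keys must match the loss keys above.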
        for k, opt in self.get_all_optimizers().items():
            opt.update(loss.get, k)