path: root/become_yukarin/updater.py
from typing import Optional

import chainer
import numpy
from chainer import reporter

from .config import LossConfig
from .config import ModelConfig
from .model import Aligner
from .model import Discriminator
from .model import Predictor


class Updater(chainer.training.StandardUpdater):
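    """Updater that trains the predictor (and optional aligner) with a masked
    L1 loss, and adds softplus GAN and DRAGAN gradient-penalty losses when a
    discriminator is given; loss weights are taken from ``loss_config``.
    """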
    def __init__(
            self,
            loss_config: LossConfig,
            model_config: ModelConfig,
            predictor: Predictor,
            aligner: Optional[Aligner] = None,
            discriminator: Optional[Discriminator] = None,
            *args,
            **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss_config = loss_config
        self.model_config = model_config
        self.predictor = predictor
        self.aligner = aligner
        self.discriminator = discriminator

    def forward(self, input, target, mask):
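        """Compute the losses for one batch.

        ``mask`` marks the valid elements of ``target``; the L1 loss is
        averaged over the masked elements only. Returns a dict of loss
        variables keyed by optimizer name ('predictor', plus 'aligner' /
        'discriminator' when those models are present), as consumed by
        ``update_core``.
        """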
        xp = self.predictor.xp

        input = chainer.as_variable(input)
        target = chainer.as_variable(target)
        mask = chainer.as_variable(mask)

        if self.aligner is not None:
            input = self.aligner(input)
        y = self.predictor(input)

        loss_l1 = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
        loss_l1 = loss_l1 / chainer.functions.sum(mask)
        reporter.report({'l1': loss_l1}, self.predictor)

        if self.discriminator is not None:
            pair_fake = chainer.functions.concat([y * mask, input])
            pair_true = chainer.functions.concat([target * mask, input])

            # DRAGAN
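            # (gradient penalty: perturb the real pairs with uniform noise
            # scaled by their per-feature std over the batch, then push the
            # discriminator's gradient norm at those points towards 1)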
            if chainer.config.train:  # gradients are not available at test time
                std = xp.std(pair_true.data, axis=0, keepdims=True)
                rand = xp.random.uniform(0, 1, pair_true.shape).astype(xp.float32)
                perturb = chainer.Variable(pair_true.data + 0.5 * rand * std)
                grad, = chainer.grad([self.discriminator(perturb)], [perturb], enable_double_backprop=True)
                grad = chainer.functions.sqrt(chainer.functions.batch_l2_norm_squared(grad))
                loss_grad = chainer.functions.mean_squared_error(grad, xp.ones_like(grad.data, numpy.float32))
                reporter.report({'grad': loss_grad}, self.discriminator)

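                # debugging aid: drop into an interactive shell if the
                # gradient penalty turned into NaN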
                if xp.any(xp.isnan(loss_grad.data)):
                    import code
                    code.interact(local=locals())

            # GAN
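            # (non-saturating softplus losses: the discriminator learns to push
            # d_true up and d_fake down, the generator to push d_fake up)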
            d_fake = self.discriminator(pair_fake)
            d_true = self.discriminator(pair_true)
            loss_dis_f = chainer.functions.average(chainer.functions.softplus(d_fake))
            loss_dis_t = chainer.functions.average(chainer.functions.softplus(-d_true))
            loss_gen_f = chainer.functions.average(chainer.functions.softplus(-d_fake))
            reporter.report({'fake': loss_dis_f}, self.discriminator)
            reporter.report({'true': loss_dis_t}, self.discriminator)

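            # classification metrics on the raw discriminator outputs
            # thresholded at 0.5, reported for monitoring only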
            tp = (d_true.data > 0.5).sum()
            fp = (d_fake.data > 0.5).sum()
            fn = (d_true.data <= 0.5).sum()
            tn = (d_fake.data <= 0.5).sum()
            accuracy = (tp + tn) / (tp + fp + fn + tn)
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            reporter.report({'accuracy': accuracy}, self.discriminator)
            reporter.report({'precision': precision}, self.discriminator)
            reporter.report({'recall': recall}, self.discriminator)

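        # weight each loss term and key it by the optimizer that will consume it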
        loss = {'predictor': loss_l1 * self.loss_config.l1}

        if self.aligner is not None:
            loss['aligner'] = loss_l1 * self.loss_config.l1
            reporter.report({'loss': loss['aligner']}, self.aligner)

        if self.discriminator is not None:
            loss['discriminator'] = \
                loss_dis_f * self.loss_config.discriminator_fake + \
                loss_dis_t * self.loss_config.discriminator_true
            if chainer.config.train:  # gradients are not available at test time
                loss['discriminator'] += loss_grad * self.loss_config.discriminator_grad
            reporter.report({'loss': loss['discriminator']}, self.discriminator)
            loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake

        reporter.report({'loss': loss['predictor']}, self.predictor)
        return loss

    def update_core(self):
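        # run every registered optimizer once, each minimizing its own entry
        # of the loss dict returned by forward()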
        batch = self.get_iterator('main').next()
        loss = self.forward(**self.converter(batch, self.device))

        for k, opt in self.get_all_optimizers().items():
            opt.update(loss.get, k)