import chainer
import numpy
from chainer import reporter

from become_yukarin.config.config import LossConfig
from become_yukarin.model.model import Aligner
from become_yukarin.model.model import Discriminator
from become_yukarin.model.model import Predictor


class Updater(chainer.training.StandardUpdater):
    def __init__(
            self,
            loss_config: LossConfig,
            predictor: Predictor,
            aligner: Aligner = None,
            discriminator: Discriminator = None,
            *args,
            **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)
        self.loss_config = loss_config
        self.predictor = predictor
        self.aligner = aligner
        self.discriminator = discriminator

    def forward(self, input, target, mask):
        xp = self.predictor.xp
        input = chainer.as_variable(input)
        target = chainer.as_variable(target)
        mask = chainer.as_variable(mask)

        if self.aligner is not None:
            input = self.aligner(input)

        y = self.predictor(input)

        # masked L1 loss between prediction and target, normalized by the mask size
        loss_l1 = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
        loss_l1 = loss_l1 / chainer.functions.sum(mask)
        reporter.report({'l1': loss_l1}, self.predictor)

        if self.discriminator is not None:
            # conditional discriminator: judge (output, input) pairs
            pair_fake = chainer.functions.concat([y * mask, input])
            pair_true = chainer.functions.concat([target * mask, input])

            # DRAGAN gradient penalty around perturbed real samples
            if chainer.config.train:  # grad is not available at test time
                std = xp.std(pair_true.data, axis=0, keepdims=True)
                rand = xp.random.uniform(0, 1, pair_true.shape).astype(xp.float32)
                perturb = chainer.Variable(pair_true.data + 0.5 * rand * std)
                grad, = chainer.grad([self.discriminator(perturb)], [perturb], enable_double_backprop=True)
                grad = chainer.functions.sqrt(chainer.functions.batch_l2_norm_squared(grad))
                loss_grad = chainer.functions.mean_squared_error(grad, xp.ones_like(grad.data, numpy.float32))
                reporter.report({'grad': loss_grad}, self.discriminator)

                if xp.any(xp.isnan(loss_grad.data)):
                    # drop into an interactive shell for debugging when the penalty diverges
                    import code
                    code.interact(local=locals())

            # GAN losses (softplus form of the non-saturating objective)
            d_fake = self.discriminator(pair_fake)
            d_true = self.discriminator(pair_true)
            loss_dis_f = chainer.functions.average(chainer.functions.softplus(d_fake))
            loss_dis_t = chainer.functions.average(chainer.functions.softplus(-d_true))
            loss_gen_f = chainer.functions.average(chainer.functions.softplus(-d_fake))
            reporter.report({'fake': loss_dis_f}, self.discriminator)
            reporter.report({'true': loss_dis_t}, self.discriminator)

            # discriminator statistics, treating "real" as the positive class
            tp = (d_true.data > 0.5).sum()
            fp = (d_fake.data > 0.5).sum()
            fn = (d_true.data <= 0.5).sum()
            tn = (d_fake.data <= 0.5).sum()
            accuracy = (tp + tn) / (tp + fp + fn + tn)
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            reporter.report({'accuracy': accuracy}, self.discriminator)
            reporter.report({'precision': precision}, self.discriminator)
            reporter.report({'recall': recall}, self.discriminator)

        # weighted total losses, keyed by the optimizer names used in update_core
        loss = {'predictor': loss_l1 * self.loss_config.l1}

        if self.aligner is not None:
            loss['aligner'] = loss_l1 * self.loss_config.l1
            reporter.report({'loss': loss['aligner']}, self.aligner)

        if self.discriminator is not None:
            loss['discriminator'] = \
                loss_dis_f * self.loss_config.discriminator_fake + \
                loss_dis_t * self.loss_config.discriminator_true
            if chainer.config.train:  # grad is not available at test time
                loss['discriminator'] += loss_grad * self.loss_config.discriminator_grad
            reporter.report({'loss': loss['discriminator']}, self.discriminator)

            loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake

        reporter.report({'loss': loss['predictor']}, self.predictor)
        return loss

    def update_core(self):
        batch = self.get_iterator('main').next()
        loss = self.forward(**self.converter(batch, self.device))
        for k, opt in self.get_all_optimizers().items():
            opt.update(loss.get, k)
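

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module): one way to wire this Updater
# into a chainer Trainer. The optimizer names must match the keys of the loss
# dict returned by forward() ('predictor', 'aligner', 'discriminator'), since
# update_core() calls opt.update(loss.get, k) for every registered optimizer.
# The function name, batch size, iteration count, and Adam settings below are
# illustrative placeholders, not values prescribed by this module.
# ----------------------------------------------------------------------------
def _example_training_setup(dataset, predictor, discriminator, loss_config):
    # Each dataset example is expected to be a dict with 'input', 'target', and
    # 'mask' keys, so the default converter yields the keyword arguments that
    # forward() expects.
    iterator = chainer.iterators.SerialIterator(dataset, batch_size=16)

    opt_predictor = chainer.optimizers.Adam()
    opt_predictor.setup(predictor)
    opt_discriminator = chainer.optimizers.Adam()
    opt_discriminator.setup(discriminator)

    updater = Updater(
        loss_config=loss_config,
        predictor=predictor,
        discriminator=discriminator,
        iterator=iterator,
        optimizer={'predictor': opt_predictor, 'discriminator': opt_discriminator},
        device=-1,  # CPU; pass a GPU id such as 0 to train on GPU
    )
    return chainer.training.Trainer(updater, (100000, 'iteration'), out='result')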