Diffstat (limited to 'become_yukarin/updater/updater.py')
 become_yukarin/updater/updater.py | 116 ++++++++++++++++++-----------------------
 1 file changed, 50 insertions(+), 66 deletions(-)
diff --git a/become_yukarin/updater/updater.py b/become_yukarin/updater/updater.py
index 8dcb215..eb51068 100644
--- a/become_yukarin/updater/updater.py
+++ b/become_yukarin/updater/updater.py
@@ -1,9 +1,7 @@
 import chainer
-import numpy
-from chainer import reporter
+import chainer.functions as F
 
 from become_yukarin.config.config import LossConfig
-from become_yukarin.model.model import Aligner
 from become_yukarin.model.model import Discriminator
 from become_yukarin.model.model import Predictor
 
@@ -13,91 +11,77 @@ class Updater(chainer.training.StandardUpdater):
             self,
             loss_config: LossConfig,
             predictor: Predictor,
-            aligner: Aligner = None,
-            discriminator: Discriminator = None,
+            discriminator: Discriminator,
             *args,
             **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
         self.loss_config = loss_config
         self.predictor = predictor
-        self.aligner = aligner
         self.discriminator = discriminator
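With aligner support dropped, the discriminator argument is no longer optional. For orientation, a minimal construction sketch, assuming a LossConfig exposing the mse and adversarial weights this file reads, plus the usual chainer.training.StandardUpdater keyword arguments (the surrounding training-script names are illustrative, not from this commit):

# Hypothetical wiring; LossConfig's constructor signature is assumed,
# its field names are taken from the loss terms used below.
updater = Updater(
    loss_config=LossConfig(mse=1.0, adversarial=1.0),
    predictor=predictor,
    discriminator=discriminator,
    iterator=train_iter,
    optimizer={'predictor': opt_predictor, 'discriminator': opt_discriminator},
    device=0,
)

The two-entry optimizer dict matches the get_optimizer('predictor') / get_optimizer('discriminator') calls in update_core below.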
 
-    def forward(self, input, target, mask):
-        xp = self.predictor.xp
+    def _loss_predictor(self, predictor, output, target, d_fake):
+        b, _, t = d_fake.data.shape
 
-        input = chainer.as_variable(input)
-        target = chainer.as_variable(target)
-        mask = chainer.as_variable(mask)
+        loss_mse = F.mean_absolute_error(output, target)
+        chainer.report({'mse': loss_mse}, predictor)
 
-        if self.aligner is not None:
-            input = self.aligner(input)
-        y = self.predictor(input)
+        loss_adv = F.sum(F.softplus(-d_fake)) / (b * t)
+        chainer.report({'adversarial': loss_adv}, predictor)
 
-        loss_l1 = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
-        loss_l1 = loss_l1 / chainer.functions.sum(mask)
-        reporter.report({'l1': loss_l1}, self.predictor)
+        loss = self.loss_config.mse * loss_mse + self.loss_config.adversarial * loss_adv
+        chainer.report({'loss': loss}, predictor)
+        return loss
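The new _loss_predictor combines a reconstruction term with the non-saturating GAN generator loss: softplus(-d_fake) equals -log sigmoid(d_fake), so minimizing it pushes the discriminator's logits on generated output upward. Note that the value reported as 'mse' is computed with F.mean_absolute_error, i.e. an L1 distance despite the key name. A self-contained NumPy sketch of the same two terms, assuming logits of shape (batch, channels, time):

import numpy as np

def softplus(x):
    # numerically stable log(1 + exp(x))
    return np.logaddexp(0.0, x)

def predictor_loss(output, target, d_fake, w_mse=1.0, w_adv=1.0):
    b, _, t = d_fake.shape
    loss_mae = np.abs(output - target).mean()      # reported as 'mse' above
    loss_adv = softplus(-d_fake).sum() / (b * t)   # non-saturating generator loss
    return w_mse * loss_mae + w_adv * loss_adv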
 
-        if self.discriminator is not None:
-            pair_fake = chainer.functions.concat([y * mask, input])
-            pair_true = chainer.functions.concat([target * mask, input])
+    def _loss_discriminator(self, discriminator, d_real, d_fake):
+        b, _, t = d_real.data.shape
 
-            # DRAGAN
-            if chainer.config.train:  # grad is not available on test
-                std = xp.std(pair_true.data, axis=0, keepdims=True)
-                rand = xp.random.uniform(0, 1, pair_true.shape).astype(xp.float32)
-                perturb = chainer.Variable(pair_true.data + 0.5 * rand * std)
-                grad, = chainer.grad([self.discriminator(perturb)], [perturb], enable_double_backprop=True)
-                grad = chainer.functions.sqrt(chainer.functions.batch_l2_norm_squared(grad))
-                loss_grad = chainer.functions.mean_squared_error(grad, xp.ones_like(grad.data, numpy.float32))
-                reporter.report({'grad': loss_grad}, self.discriminator)
+        loss_real = F.sum(F.softplus(-d_real)) / (b * t)
+        chainer.report({'real': loss_real}, discriminator)
 
-                if xp.any(xp.isnan(loss_grad.data)):
-                    import code
-                    code.interact(local=locals())
+        loss_fake = F.sum(F.softplus(d_fake)) / (b * t)
+        chainer.report({'fake': loss_fake}, discriminator)
 
-            # GAN
-            d_fake = self.discriminator(pair_fake)
-            d_true = self.discriminator(pair_true)
-            loss_dis_f = chainer.functions.average(chainer.functions.softplus(d_fake))
-            loss_dis_t = chainer.functions.average(chainer.functions.softplus(-d_true))
-            loss_gen_f = chainer.functions.average(chainer.functions.softplus(-d_fake))
-            reporter.report({'fake': loss_dis_f}, self.discriminator)
-            reporter.report({'true': loss_dis_t}, self.discriminator)
+        loss = loss_real + loss_fake
+        chainer.report({'loss': loss}, discriminator)
 
-            tp = (d_true.data > 0.5).sum()
-            fp = (d_fake.data > 0.5).sum()
-            fn = (d_true.data <= 0.5).sum()
-            tn = (d_fake.data <= 0.5).sum()
-            accuracy = (tp + tn) / (tp + fp + fn + tn)
-            precision = tp / (tp + fp)
-            recall = tp / (tp + fn)
-            reporter.report({'accuracy': accuracy}, self.discriminator)
-            reporter.report({'precision': precision}, self.discriminator)
-            reporter.report({'recall': recall}, self.discriminator)
+        tp = (d_real.data > 0.5).sum()
+        fp = (d_fake.data > 0.5).sum()
+        fn = (d_real.data <= 0.5).sum()
+        tn = (d_fake.data <= 0.5).sum()
+        accuracy = (tp + tn) / (tp + fp + fn + tn)
+        precision = tp / (tp + fp)
+        recall = tp / (tp + fn)
+        chainer.report({'accuracy': accuracy}, discriminator)
+        chainer.report({'precision': precision}, discriminator)
+        chainer.report({'recall': recall}, discriminator)
+        return loss
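_loss_discriminator is the matching sigmoid-cross-entropy objective: softplus(-d_real) + softplus(d_fake) is minimized when real examples receive large logits and generated ones small logits. The reported accuracy/precision/recall threshold the raw logits at 0.5 (a logit of 0.5 maps to a sigmoid probability of about 0.62, so this is a slightly conservative cut rather than the usual 0). A NumPy sketch, reusing softplus from the previous snippet:

def discriminator_loss(d_real, d_fake):
    b, _, t = d_real.shape
    loss_real = softplus(-d_real).sum() / (b * t)  # push D(real) up
    loss_fake = softplus(d_fake).sum() / (b * t)   # push D(fake) down
    return loss_real + loss_fake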
 
-        loss = {'predictor': loss_l1 * self.loss_config.l1}
+    def forward(self, input, target, mask):
+        input = chainer.as_variable(input)
+        target = chainer.as_variable(target)
+        mask = chainer.as_variable(mask)
 
-        if self.aligner is not None:
-            loss['aligner'] = loss_l1 * self.loss_config.l1
-            reporter.report({'loss': loss['aligner']}, self.aligner)
+        output = self.predictor(input)
+        output = output * mask
+        target = target * mask
 
-        if self.discriminator is not None:
-            loss['discriminator'] = \
-                loss_dis_f * self.loss_config.discriminator_fake + \
-                loss_dis_t * self.loss_config.discriminator_true
-            if chainer.config.train:  # grad is not available on test
-                loss['discriminator'] += loss_grad * self.loss_config.discriminator_grad
-            reporter.report({'loss': loss['discriminator']}, self.discriminator)
-            loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake
+        d_fake = self.discriminator(input, output)
+        d_real = self.discriminator(input, target)
 
-        reporter.report({'loss': loss['predictor']}, self.predictor)
+        loss = {
+            'predictor': self._loss_predictor(self.predictor, output, target, d_fake),
+            'discriminator': self._loss_discriminator(self.discriminator, d_real, d_fake),
+        }
         return loss
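forward applies the mask to both the prediction and the target before any loss or discriminator call, and the discriminator is conditional: it scores an (input, candidate-output) pair rather than the output alone. A toy shape check with stand-in callables (all names and shapes here are assumptions for illustration):

import numpy as np

rng = np.random.default_rng(0)
b, c, t = 4, 2, 16
input_ = rng.standard_normal((b, c, t)).astype(np.float32)
target = rng.standard_normal((b, c, t)).astype(np.float32)
mask = (rng.random((b, 1, t)) > 0.1).astype(np.float32)  # broadcast over channels

def predictor(x):
    return 0.9 * x  # stand-in for the real network

def discriminator(x, y):
    return (x + y).mean(1, keepdims=True)  # stand-in logits, shape (b, 1, t)

output = predictor(input_) * mask
d_fake = discriminator(input_, output)
d_real = discriminator(input_, target * mask)
assert d_fake.shape == (b, 1, t)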
 
     def update_core(self):
+        opt_predictor = self.get_optimizer('predictor')
+        opt_discriminator = self.get_optimizer('discriminator')
+
         batch = self.get_iterator('main').next()
-        loss = self.forward(**self.converter(batch, self.device))
+        batch = self.converter(batch, self.device)
+        loss = self.forward(**batch)
 
-        for k, opt in self.get_all_optimizers().items():
-            opt.update(loss.get, k)
+        opt_predictor.update(loss.get, 'predictor')
+        opt_discriminator.update(loss.get, 'discriminator')
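In Chainer, optimizer.update(lossfun, *args) evaluates lossfun(*args), clears the gradients of the optimizer's target link, backpropagates from the returned variable, and applies the update rule. Passing the dict method loss.get plus a key therefore lets each optimizer pull its own precomputed loss variable out of forward's return value. Spelled out, each update call above is roughly equivalent to:

# Roughly what opt_predictor.update(loss.get, 'predictor') does:
self.predictor.cleargrads()      # zero only the predictor's gradients
loss['predictor'].backward()     # backprop through the shared graph
opt_predictor.update()           # apply the update rule with the fresh grads

# opt_discriminator.update(loss.get, 'discriminator') then repeats the
# same three steps for the discriminator; its cleargrads() also discards
# the stale gradients the predictor backward pass deposited on it.

Updating the predictor first therefore does not corrupt the discriminator step, because each cleargrads() wipes whatever the other loss's backward pass left behind.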