| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-02-02 16:59:33 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-02-10 20:52:52 +0900 |
| commit | 0b410c0bbacc147950438423835547b74859aac1 (patch) | |
| tree | a5f8dc39e4d980bcfb09b21f47bf715d860c2e45 /become_yukarin/updater/updater.py | |
| parent | 4ff5252dbdc0cdaeecc7fbe399c629e4d29de3a3 (diff) | |
pix2pix convert model
Diffstat (limited to 'become_yukarin/updater/updater.py')
| -rw-r--r-- | become_yukarin/updater/updater.py | 116 |
1 file changed, 50 insertions, 66 deletions
```diff
diff --git a/become_yukarin/updater/updater.py b/become_yukarin/updater/updater.py
index 8dcb215..eb51068 100644
--- a/become_yukarin/updater/updater.py
+++ b/become_yukarin/updater/updater.py
@@ -1,9 +1,7 @@
 import chainer
-import numpy
-from chainer import reporter
+import chainer.functions as F
 
 from become_yukarin.config.config import LossConfig
-from become_yukarin.model.model import Aligner
 from become_yukarin.model.model import Discriminator
 from become_yukarin.model.model import Predictor
 
@@ -13,91 +11,77 @@ class Updater(chainer.training.StandardUpdater):
             self,
             loss_config: LossConfig,
             predictor: Predictor,
-            aligner: Aligner = None,
-            discriminator: Discriminator = None,
+            discriminator: Discriminator,
             *args,
             **kwargs,
     ) -> None:
         super().__init__(*args, **kwargs)
         self.loss_config = loss_config
         self.predictor = predictor
-        self.aligner = aligner
         self.discriminator = discriminator
 
-    def forward(self, input, target, mask):
-        xp = self.predictor.xp
+    def _loss_predictor(self, predictor, output, target, d_fake):
+        b, _, t = d_fake.data.shape
 
-        input = chainer.as_variable(input)
-        target = chainer.as_variable(target)
-        mask = chainer.as_variable(mask)
+        loss_mse = (F.mean_absolute_error(output, target))
+        chainer.report({'mse': loss_mse}, predictor)
 
-        if self.aligner is not None:
-            input = self.aligner(input)
-        y = self.predictor(input)
+        loss_adv = F.sum(F.softplus(-d_fake)) / (b * t)
+        chainer.report({'adversarial': loss_adv}, predictor)
 
-        loss_l1 = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
-        loss_l1 = loss_l1 / chainer.functions.sum(mask)
-        reporter.report({'l1': loss_l1}, self.predictor)
+        loss = self.loss_config.mse * loss_mse + self.loss_config.adversarial * loss_adv
+        chainer.report({'loss': loss}, predictor)
+        return loss
 
-        if self.discriminator is not None:
-            pair_fake = chainer.functions.concat([y * mask, input])
-            pair_true = chainer.functions.concat([target * mask, input])
+    def _loss_discriminator(self, discriminator, d_real, d_fake):
+        b, _, t = d_real.data.shape
 
-            # DRAGAN
-            if chainer.config.train:  # grad is not available on test
-                std = xp.std(pair_true.data, axis=0, keepdims=True)
-                rand = xp.random.uniform(0, 1, pair_true.shape).astype(xp.float32)
-                perturb = chainer.Variable(pair_true.data + 0.5 * rand * std)
-                grad, = chainer.grad([self.discriminator(perturb)], [perturb], enable_double_backprop=True)
-                grad = chainer.functions.sqrt(chainer.functions.batch_l2_norm_squared(grad))
-                loss_grad = chainer.functions.mean_squared_error(grad, xp.ones_like(grad.data, numpy.float32))
-                reporter.report({'grad': loss_grad}, self.discriminator)
+        loss_real = F.sum(F.softplus(-d_real)) / (b * t)
+        chainer.report({'real': loss_real}, discriminator)
 
-                if xp.any(xp.isnan(loss_grad.data)):
-                    import code
-                    code.interact(local=locals())
+        loss_fake = F.sum(F.softplus(d_fake)) / (b * t)
+        chainer.report({'fake': loss_fake}, discriminator)
 
-            # GAN
-            d_fake = self.discriminator(pair_fake)
-            d_true = self.discriminator(pair_true)
-            loss_dis_f = chainer.functions.average(chainer.functions.softplus(d_fake))
-            loss_dis_t = chainer.functions.average(chainer.functions.softplus(-d_true))
-            loss_gen_f = chainer.functions.average(chainer.functions.softplus(-d_fake))
-            reporter.report({'fake': loss_dis_f}, self.discriminator)
-            reporter.report({'true': loss_dis_t}, self.discriminator)
+        loss = loss_real + loss_fake
+        chainer.report({'loss': loss}, discriminator)
 
-            tp = (d_true.data > 0.5).sum()
-            fp = (d_fake.data > 0.5).sum()
-            fn = (d_true.data <= 0.5).sum()
-            tn = (d_fake.data <= 0.5).sum()
-            accuracy = (tp + tn) / (tp + fp + fn + tn)
-            precision = tp / (tp + fp)
-            recall = tp / (tp + fn)
-            reporter.report({'accuracy': accuracy}, self.discriminator)
-            reporter.report({'precision': precision}, self.discriminator)
-            reporter.report({'recall': recall}, self.discriminator)
+        tp = (d_real.data > 0.5).sum()
+        fp = (d_fake.data > 0.5).sum()
+        fn = (d_real.data <= 0.5).sum()
+        tn = (d_fake.data <= 0.5).sum()
+        accuracy = (tp + tn) / (tp + fp + fn + tn)
+        precision = tp / (tp + fp)
+        recall = tp / (tp + fn)
+        chainer.report({'accuracy': accuracy}, self.discriminator)
+        chainer.report({'precision': precision}, self.discriminator)
+        chainer.report({'recall': recall}, self.discriminator)
+        return loss
 
-        loss = {'predictor': loss_l1 * self.loss_config.l1}
+    def forward(self, input, target, mask):
+        input = chainer.as_variable(input)
+        target = chainer.as_variable(target)
+        mask = chainer.as_variable(mask)
 
-        if self.aligner is not None:
-            loss['aligner'] = loss_l1 * self.loss_config.l1
-            reporter.report({'loss': loss['aligner']}, self.aligner)
+        output = self.predictor(input)
+        output = output * mask
+        target = target * mask
 
-        if self.discriminator is not None:
-            loss['discriminator'] = \
-                loss_dis_f * self.loss_config.discriminator_fake + \
-                loss_dis_t * self.loss_config.discriminator_true
-            if chainer.config.train:  # grad is not available on test
-                loss['discriminator'] += loss_grad * self.loss_config.discriminator_grad
-            reporter.report({'loss': loss['discriminator']}, self.discriminator)
-            loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake
+        d_fake = self.discriminator(input, output)
+        d_real = self.discriminator(input, target)
 
-        reporter.report({'loss': loss['predictor']}, self.predictor)
+        loss = {
+            'predictor': self._loss_predictor(self.predictor, output, target, d_fake),
+            'discriminator': self._loss_discriminator(self.discriminator, d_real, d_fake),
+        }
         return loss
 
     def update_core(self):
+        opt_predictor = self.get_optimizer('predictor')
+        opt_discriminator = self.get_optimizer('discriminator')
+
         batch = self.get_iterator('main').next()
-        loss = self.forward(**self.converter(batch, self.device))
+        batch = self.converter(batch, self.device)
+        loss = self.forward(**batch)
 
-        for k, opt in self.get_all_optimizers().items():
-            opt.update(loss.get, k)
+        opt_predictor.update(loss.get, 'predictor')
+        opt_discriminator.update(loss.get, 'discriminator')
```
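The refactor above swaps the masked-L1 + DRAGAN setup for the plain pix2pix recipe: an L1 reconstruction term (reported as 'mse', though F.mean_absolute_error is MAE) plus non-saturating GAN terms on raw discriminator logits, using the identity softplus(-d) = -log sigmoid(d). The sketch below reproduces the same loss arithmetic in a self-contained form; ToyPredictor and ToyDiscriminator are hypothetical stand-ins for the repo's Predictor and Discriminator, which this diff does not show.

```python
import chainer
import chainer.functions as F
import chainer.links as L
import numpy as np


class ToyPredictor(chainer.Chain):
    """Hypothetical stand-in for the repo's Predictor: maps (b, c, t) -> (b, c, t)."""

    def __init__(self, channels=4):
        super().__init__()
        with self.init_scope():
            self.conv = L.ConvolutionND(1, channels, channels, ksize=3, pad=1)

    def __call__(self, x):
        return self.conv(x)


class ToyDiscriminator(chainer.Chain):
    """Hypothetical conditional discriminator: scores (condition, candidate) pairs,
    as the new forward() does with D(input, output) and D(input, target)."""

    def __init__(self, channels=4):
        super().__init__()
        with self.init_scope():
            self.conv = L.ConvolutionND(1, channels * 2, 1, ksize=3, pad=1)

    def __call__(self, condition, candidate):
        return self.conv(F.concat((condition, candidate), axis=1))


predictor = ToyPredictor()
discriminator = ToyDiscriminator()

x = np.random.randn(2, 4, 16).astype(np.float32)  # input features, (b, c, t)
y = np.random.randn(2, 4, 16).astype(np.float32)  # target features

output = predictor(x)
d_fake = discriminator(x, output)
d_real = discriminator(x, y)
b, _, t = d_fake.shape

# Predictor side, mirroring _loss_predictor: reconstruction plus a
# non-saturating adversarial term that pushes fake logits up.
loss_mse = F.mean_absolute_error(output, y)
loss_adv = F.sum(F.softplus(-d_fake)) / (b * t)
loss_predictor = 1.0 * loss_mse + 1.0 * loss_adv  # weights come from LossConfig

# Discriminator side, mirroring _loss_discriminator: real logits up, fake down.
loss_discriminator = (F.sum(F.softplus(-d_real)) + F.sum(F.softplus(d_fake))) / (b * t)
```

Continuing the sketch: the new update_core looks optimizers up by name, so both must be registered under exactly 'predictor' and 'discriminator' when the Updater is constructed. The wiring below is an assumption-laden illustration, not the repo's training script: SimpleNamespace stands in for the repo's LossConfig (only .mse and .adversarial are read here), and the toy dataset and iterator are invented.

```python
import types

from chainer.iterators import SerialIterator
from chainer.optimizers import Adam

from become_yukarin.updater.updater import Updater

# Stand-in for LossConfig; this Updater only reads .mse and .adversarial.
loss_config = types.SimpleNamespace(mse=1.0, adversarial=1.0)

# Toy dict-shaped dataset; StandardUpdater's default converter batches dict
# examples, so forward() receives input=..., target=..., mask=... keywords.
mask = np.ones_like(y)
dataset = [{'input': x[i], 'target': y[i], 'mask': mask[i]} for i in range(len(x))]
iterator = SerialIterator(dataset, batch_size=2)

opt_predictor = Adam()
opt_predictor.setup(predictor)
opt_discriminator = Adam()
opt_discriminator.setup(discriminator)

updater = Updater(
    loss_config=loss_config,
    predictor=predictor,
    discriminator=discriminator,
    iterator=iterator,
    optimizer={'predictor': opt_predictor, 'discriminator': opt_discriminator},
)
updater.update()  # one predictor step and one discriminator step
```

Because forward() returns a dict of losses and each opt.update(loss.get, key) triggers its own backward pass, the two networks are stepped independently within a single iteration, which is the standard alternating GAN update.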
