Diffstat (limited to 'become_yukarin/updater.py')
-rw-r--r--  become_yukarin/updater.py  24
1 file changed, 21 insertions(+), 3 deletions(-)
diff --git a/become_yukarin/updater.py b/become_yukarin/updater.py
index 927601f..02ea5d3 100644
--- a/become_yukarin/updater.py
+++ b/become_yukarin/updater.py
@@ -45,11 +45,27 @@ class Updater(chainer.training.StandardUpdater):
        if self.discriminator is not None:
            pair_fake = chainer.functions.concat([y * mask, input])
            pair_true = chainer.functions.concat([target * mask, input])
+
+            # DRAGAN gradient penalty
+            if chainer.config.train:  # the gradient is not available at test time
+                std = xp.std(pair_true.data, axis=0, keepdims=True)
+                rand = xp.random.uniform(0, 1, pair_true.shape).astype(xp.float32)
+                perturb = chainer.Variable(pair_true.data + 0.5 * rand * std)
+                grad, = chainer.grad([self.discriminator(perturb)], [perturb], enable_double_backprop=True)
+                grad = chainer.functions.sqrt(chainer.functions.batch_l2_norm_squared(grad))
+                loss_grad = chainer.functions.mean_squared_error(grad, xp.ones_like(grad.data, numpy.float32))
+                reporter.report({'grad': loss_grad}, self.discriminator)
+
+                if xp.any(xp.isnan(loss_grad.data)):
+                    # debugging aid: drop into an interactive shell if the penalty becomes NaN
+                    import code
+                    code.interact(local=locals())
+
+            # GAN
            d_fake = self.discriminator(pair_fake)
            d_true = self.discriminator(pair_true)
-            loss_dis_f = chainer.functions.mean_squared_error(d_fake, xp.zeros_like(d_fake.data, numpy.float32))
-            loss_dis_t = chainer.functions.mean_squared_error(d_true, xp.ones_like(d_true.data, numpy.float32))
-            loss_gen_f = chainer.functions.mean_squared_error(d_fake, xp.ones_like(d_fake.data, numpy.float32))
+            loss_dis_f = chainer.functions.average(chainer.functions.softplus(d_fake))
+            loss_dis_t = chainer.functions.average(chainer.functions.softplus(-d_true))
+            loss_gen_f = chainer.functions.average(chainer.functions.softplus(-d_fake))
            reporter.report({'fake': loss_dis_f}, self.discriminator)
            reporter.report({'true': loss_dis_t}, self.discriminator)
@@ -63,6 +79,8 @@ class Updater(chainer.training.StandardUpdater):
            loss['discriminator'] = \
                loss_dis_f * self.loss_config.discriminator_fake + \
                loss_dis_t * self.loss_config.discriminator_true
+            if chainer.config.train:  # loss_grad is only computed in training mode
+                loss['discriminator'] += loss_grad * self.loss_config.discriminator_grad
            reporter.report({'loss': loss['discriminator']}, self.discriminator)
            loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake
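
The added block is a DRAGAN-style gradient penalty (Kodali et al., 2017): the real input pairs are perturbed with uniform noise scaled by half the per-feature batch standard deviation, and the L2 norm of the discriminator's gradient at those perturbed points is regressed toward 1. Below is a minimal standalone sketch of the same computation, assuming only some callable discriminator that returns raw scores; the dragan_penalty name and perturb_scale parameter are illustrative, not part of become_yukarin's API.

import chainer
import chainer.functions as F
import numpy as np

def dragan_penalty(discriminator, x_real, xp=np, perturb_scale=0.5):
    # perturb the real inputs with uniform noise scaled by the
    # per-feature standard deviation of the batch
    std = xp.std(x_real, axis=0, keepdims=True)
    noise = xp.random.uniform(0, 1, x_real.shape).astype(xp.float32)
    perturbed = chainer.Variable(x_real + perturb_scale * noise * std)

    # gradient of the discriminator output w.r.t. the perturbed input;
    # double backprop keeps the penalty itself differentiable
    grad, = chainer.grad([discriminator(perturbed)], [perturbed],
                         enable_double_backprop=True)

    # penalize deviation of the per-sample gradient L2 norm from 1
    grad_norm = F.sqrt(F.batch_l2_norm_squared(grad))
    return F.mean_squared_error(
        grad_norm, xp.ones_like(grad_norm.data, np.float32))

With perturb_scale=0.5 this mirrors the loss_grad term in the patch, which the second hunk then weights by loss_config.discriminator_grad in the discriminator's total loss.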
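The loss rewrite swaps the previous least-squares (LSGAN-style) objectives for the standard sigmoid cross-entropy formulation: for a raw score d, softplus(d) = -log(1 - sigmoid(d)) penalizes samples that should be classified as fake, and softplus(-d) = -log(sigmoid(d)) penalizes samples that should be classified as real, so loss_gen_f is the non-saturating generator loss. A quick numerical check of those identities:

import numpy as np
import chainer.functions as F

d = np.array([-2.0, 0.0, 3.0], dtype=np.float32)  # raw discriminator scores
sig = 1.0 / (1.0 + np.exp(-d))

# softplus(d) == -log(1 - sigmoid(d)): loss for samples labelled fake
assert np.allclose(F.softplus(d).data, -np.log1p(-sig), atol=1e-5)
# softplus(-d) == -log(sigmoid(d)): loss for samples labelled real
assert np.allclose(F.softplus(-d).data, -np.log(sig), atol=1e-5)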