author     Hiroshiba Kazuyuki <hihokaruta@gmail.com>  2018-01-14 07:40:07 +0900
committer  Hiroshiba Kazuyuki <hihokaruta@gmail.com>  2018-01-14 07:40:07 +0900
commit     2be3f03adc5695f82c6ab86da780108f786ed014 (patch)
tree       ae4b95aa3e45706598e66cc00ff5ad9f00ef97a9 /become_yukarin/updater
parent     f9185301a22f1632b16dd5266197bb40cb7c302e (diff)
Super-resolution
Diffstat (limited to 'become_yukarin/updater')
-rw-r--r--  become_yukarin/updater/__init__.py      2
-rw-r--r--  become_yukarin/updater/sr_updater.py   69
-rw-r--r--  become_yukarin/updater/updater.py     103
3 files changed, 174 insertions(+), 0 deletions(-)
diff --git a/become_yukarin/updater/__init__.py b/become_yukarin/updater/__init__.py
new file mode 100644
index 0000000..d85003a
--- /dev/null
+++ b/become_yukarin/updater/__init__.py
@@ -0,0 +1,2 @@
+from . import sr_updater
+from . import updater
diff --git a/become_yukarin/updater/sr_updater.py b/become_yukarin/updater/sr_updater.py
new file mode 100644
index 0000000..a6b1d22
--- /dev/null
+++ b/become_yukarin/updater/sr_updater.py
@@ -0,0 +1,69 @@
+import chainer
+import chainer.functions as F
+
+from become_yukarin.config.sr_config import SRLossConfig
+from become_yukarin.model.sr_model import SRDiscriminator
+from become_yukarin.model.sr_model import SRPredictor
+
+
+class SRUpdater(chainer.training.StandardUpdater):
+ def __init__(
+ self,
+ loss_config: SRLossConfig,
+ predictor: SRPredictor,
+ discriminator: SRDiscriminator,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.loss_config = loss_config
+ self.predictor = predictor
+ self.discriminator = discriminator
+
+    def _loss_predictor(self, predictor, output, target, d_fake):
+        b, _, w, h = d_fake.data.shape  # per-patch discriminator logits: (batch, ch, w, h)
+
+        loss_mse = F.mean_absolute_error(output, target)  # despite the name, an L1 (MAE) loss
+ chainer.report({'mse': loss_mse}, predictor)
+
+        loss_adv = F.sum(F.softplus(-d_fake)) / (b * w * h)  # non-saturating adversarial loss, averaged per element
+ chainer.report({'adversarial': loss_adv}, predictor)
+
+ loss = self.loss_config.mse * loss_mse + self.loss_config.adversarial * loss_adv
+ chainer.report({'loss': loss}, predictor)
+ return loss
+
+ def _loss_discriminator(self, discriminator, y_in, y_out):
+ b, _, w, h = y_in.data.shape
+
+        loss_real = F.sum(F.softplus(-y_in)) / (b * w * h)  # y_in: logits for real pairs
+        chainer.report({'real': loss_real}, discriminator)
+
+        loss_fake = F.sum(F.softplus(y_out)) / (b * w * h)  # y_out: logits for generated pairs
+ chainer.report({'fake': loss_fake}, discriminator)
+
+ loss = loss_real + loss_fake
+ chainer.report({'loss': loss}, discriminator)
+ return loss
+
+ def forward(self, input, target):
+ output = self.predictor(input)
+        d_fake = self.discriminator(input, output)  # logits for (input, prediction)
+        d_real = self.discriminator(input, target)  # logits for (input, ground truth)
+
+ loss = {
+ 'predictor': self._loss_predictor(self.predictor, output, target, d_fake),
+ 'discriminator': self._loss_discriminator(self.discriminator, d_real, d_fake),
+ }
+ return loss
+
+ def update_core(self):
+ opt_predictor = self.get_optimizer('predictor')
+ opt_discriminator = self.get_optimizer('discriminator')
+
+        batch = self.get_iterator('main').next()
+        batch = self.converter(batch, self.device)  # concat_examples: dict samples -> dict of batched arrays
+        loss = self.forward(**batch)  # batch keys must match forward's signature: input, target
+
+        opt_predictor.update(loss.get, 'predictor')  # Optimizer.update(lossfun, *args): cleargrads, backward, step
+        opt_discriminator.update(loss.get, 'discriminator')
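
For orientation (not part of this commit), a minimal wiring sketch for the SRUpdater above. The optimizer names 'predictor' and 'discriminator' are required by update_core, and forward(**batch) implies a dataset of {'input': ..., 'target': ...} dicts; everything else here (constructor arguments, shapes, Adam settings) is an assumption:

    import chainer
    import numpy

    from become_yukarin.config.sr_config import SRLossConfig
    from become_yukarin.model.sr_model import SRDiscriminator, SRPredictor
    from become_yukarin.updater.sr_updater import SRUpdater

    predictor = SRPredictor()          # constructor arguments are hypothetical
    discriminator = SRDiscriminator()  # constructor arguments are hypothetical

    opt_predictor = chainer.optimizers.Adam()
    opt_predictor.setup(predictor)
    opt_discriminator = chainer.optimizers.Adam()
    opt_discriminator.setup(discriminator)

    # dummy dataset; real shapes depend on the spectrogram features used
    sample = {'input': numpy.zeros((1, 64, 64), dtype=numpy.float32),
              'target': numpy.zeros((1, 64, 64), dtype=numpy.float32)}
    iterator = chainer.iterators.SerialIterator([sample] * 16, batch_size=4)

    updater = SRUpdater(
        loss_config=SRLossConfig(mse=1.0, adversarial=1.0),  # assuming exactly these two weight fields
        predictor=predictor,
        discriminator=discriminator,
        iterator=iterator,
        optimizer={'predictor': opt_predictor,          # keys looked up by update_core
                   'discriminator': opt_discriminator},
    )
    chainer.training.Trainer(updater, (100000, 'iteration')).run()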
diff --git a/become_yukarin/updater/updater.py b/become_yukarin/updater/updater.py
new file mode 100644
index 0000000..ef77e77
--- /dev/null
+++ b/become_yukarin/updater/updater.py
@@ -0,0 +1,103 @@
+import chainer
+import numpy
+from chainer import reporter
+
+from become_yukarin.config.config import LossConfig
+from become_yukarin.model.model import Aligner
+from become_yukarin.model.model import Discriminator
+from become_yukarin.model.model import Predictor
+
+
+class Updater(chainer.training.StandardUpdater):
+ def __init__(
+ self,
+ loss_config: LossConfig,
+ predictor: Predictor,
+ aligner: Aligner = None,
+ discriminator: Discriminator = None,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.loss_config = loss_config
+ self.predictor = predictor
+ self.aligner = aligner
+ self.discriminator = discriminator
+
+ def forward(self, input, target, mask):
+        xp = self.predictor.xp  # numpy or cupy, matching the predictor's device
+
+ input = chainer.as_variable(input)
+ target = chainer.as_variable(target)
+ mask = chainer.as_variable(mask)
+
+        if self.aligner is not None:
+            input = self.aligner(input)  # optional alignment of the input features before prediction
+        y = self.predictor(input)
+
+        loss_l1 = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)  # masked L1
+        loss_l1 = loss_l1 / chainer.functions.sum(mask)  # averaged over valid elements only
+ reporter.report({'l1': loss_l1}, self.predictor)
+
+ if self.discriminator is not None:
+            pair_fake = chainer.functions.concat([y * mask, input])       # condition the discriminator on the input
+            pair_true = chainer.functions.concat([target * mask, input])
+
+            # DRAGAN gradient penalty: keep ||grad D|| close to 1 around perturbed real pairs
+            if chainer.config.train:  # gradients are unavailable at test time
+                std = xp.std(pair_true.data, axis=0, keepdims=True)
+                rand = xp.random.uniform(0, 1, pair_true.shape).astype(xp.float32)
+                perturb = chainer.Variable(pair_true.data + 0.5 * rand * std)  # real pair + scaled noise
+                grad, = chainer.grad([self.discriminator(perturb)], [perturb], enable_double_backprop=True)
+                grad = chainer.functions.sqrt(chainer.functions.batch_l2_norm_squared(grad))
+                loss_grad = chainer.functions.mean_squared_error(grad, xp.ones_like(grad.data, numpy.float32))
+                reporter.report({'grad': loss_grad}, self.discriminator)
+
+                if xp.any(xp.isnan(loss_grad.data)):  # debug hook: inspect state if the penalty diverges
+                    import code
+                    code.interact(local=locals())
+
+            # GAN: non-saturating softplus losses on raw logits
+ d_fake = self.discriminator(pair_fake)
+ d_true = self.discriminator(pair_true)
+ loss_dis_f = chainer.functions.average(chainer.functions.softplus(d_fake))
+ loss_dis_t = chainer.functions.average(chainer.functions.softplus(-d_true))
+ loss_gen_f = chainer.functions.average(chainer.functions.softplus(-d_fake))
+ reporter.report({'fake': loss_dis_f}, self.discriminator)
+ reporter.report({'true': loss_dis_t}, self.discriminator)
+
+            tp = (d_true.data > 0.5).sum()   # real classified as real (note: threshold on raw logits)
+            fp = (d_fake.data > 0.5).sum()   # fake classified as real
+            fn = (d_true.data <= 0.5).sum()  # real classified as fake
+            tn = (d_fake.data <= 0.5).sum()  # fake classified as fake
+            accuracy = (tp + tn) / (tp + fp + fn + tn)
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ reporter.report({'accuracy': accuracy}, self.discriminator)
+ reporter.report({'precision': precision}, self.discriminator)
+ reporter.report({'recall': recall}, self.discriminator)
+
+ loss = {'predictor': loss_l1 * self.loss_config.l1}
+
+ if self.aligner is not None:
+            loss['aligner'] = loss_l1 * self.loss_config.l1  # the aligner shares the predictor's weighted L1 objective
+ reporter.report({'loss': loss['aligner']}, self.aligner)
+
+ if self.discriminator is not None:
+ loss['discriminator'] = \
+ loss_dis_f * self.loss_config.discriminator_fake + \
+ loss_dis_t * self.loss_config.discriminator_true
+            if chainer.config.train:  # gradients are unavailable at test time
+ loss['discriminator'] += loss_grad * self.loss_config.discriminator_grad
+ reporter.report({'loss': loss['discriminator']}, self.discriminator)
+ loss['predictor'] += loss_gen_f * self.loss_config.predictor_fake
+
+ reporter.report({'loss': loss['predictor']}, self.predictor)
+ return loss
+
+ def update_core(self):
+ batch = self.get_iterator('main').next()
+ loss = self.forward(**self.converter(batch, self.device))
+
+        for k, opt in self.get_all_optimizers().items():  # optimizer names must match the loss dict keys
+            opt.update(loss.get, k)
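
The voice-conversion Updater above follows the same two- (or three-)player pattern: a masked L1 term always trains the predictor, the aligner (when present) shares that objective, and the discriminator adds non-saturating GAN losses plus the DRAGAN penalty, which pushes ||grad_x D(x~)|| toward 1 at perturbed real pairs x~ = x + 0.5 * u * std_batch. A minimal wiring sketch (not part of this commit; constructor arguments, feature shapes, and loss weights are assumptions, while the optimizer keys and the {'input', 'target', 'mask'} batch keys are required by update_core and forward above):

    import chainer
    import numpy

    from become_yukarin.config.config import LossConfig
    from become_yukarin.model.model import Discriminator, Predictor
    from become_yukarin.updater.updater import Updater

    predictor = Predictor()          # constructor arguments are hypothetical
    discriminator = Discriminator()  # constructor arguments are hypothetical

    opt_predictor = chainer.optimizers.Adam()
    opt_predictor.setup(predictor)
    opt_discriminator = chainer.optimizers.Adam()
    opt_discriminator.setup(discriminator)

    # dummy dataset; real shapes depend on the acoustic features used
    sample = {'input': numpy.zeros((1, 40, 128), dtype=numpy.float32),
              'target': numpy.zeros((1, 40, 128), dtype=numpy.float32),
              'mask': numpy.ones((1, 40, 128), dtype=numpy.float32)}
    iterator = chainer.iterators.SerialIterator([sample] * 16, batch_size=4)

    updater = Updater(
        loss_config=LossConfig(l1=1.0, predictor_fake=1.0, discriminator_fake=1.0,
                               discriminator_true=1.0, discriminator_grad=10.0),  # assumed fields
        predictor=predictor,
        discriminator=discriminator,  # pass None to train with the L1 term alone
        iterator=iterator,
        optimizer={'predictor': opt_predictor,
                   'discriminator': opt_discriminator},
    )
    chainer.training.Trainer(updater, (100000, 'iteration')).run()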