From 7bfc3321e356f24f49c790b578917e8db22bd30d Mon Sep 17 00:00:00 2001 From: Hiroshiba Kazuyuki Date: Mon, 15 Jan 2018 04:17:10 +0900 Subject: 超解像学習を可能に MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- become_yukarin/updater/sr_updater.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) (limited to 'become_yukarin/updater') diff --git a/become_yukarin/updater/sr_updater.py b/become_yukarin/updater/sr_updater.py index a6b1d22..6e2b400 100644 --- a/become_yukarin/updater/sr_updater.py +++ b/become_yukarin/updater/sr_updater.py @@ -33,17 +33,28 @@ class SRUpdater(chainer.training.StandardUpdater): chainer.report({'loss': loss}, predictor) return loss - def _loss_discriminator(self, discriminator, y_in, y_out): - b, _, w, h = y_in.data.shape + def _loss_discriminator(self, discriminator, d_real, d_fake): + b, _, w, h = d_real.data.shape - loss_real = F.sum(F.softplus(-y_in)) / (b * w * h) + loss_real = F.sum(F.softplus(-d_real)) / (b * w * h) chainer.report({'real': loss_real}, discriminator) - loss_fake = F.sum(F.softplus(y_out)) / (b * w * h) + loss_fake = F.sum(F.softplus(d_fake)) / (b * w * h) chainer.report({'fake': loss_fake}, discriminator) loss = loss_real + loss_fake chainer.report({'loss': loss}, discriminator) + + tp = (d_real.data > 0.5).sum() + fp = (d_fake.data > 0.5).sum() + fn = (d_real.data <= 0.5).sum() + tn = (d_fake.data <= 0.5).sum() + accuracy = (tp + tn) / (tp + fp + fn + tn) + precision = tp / (tp + fp) + recall = tp / (tp + fn) + chainer.report({'accuracy': accuracy}, self.discriminator) + chainer.report({'precision': precision}, self.discriminator) + chainer.report({'recall': recall}, self.discriminator) return loss def forward(self, input, target): -- cgit v1.2.3-70-g09d2