summaryrefslogtreecommitdiff
path: root/become_yukarin/updater/sr_updater.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/updater/sr_updater.py')
-rw-r--r--become_yukarin/updater/sr_updater.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/become_yukarin/updater/sr_updater.py b/become_yukarin/updater/sr_updater.py
index a6b1d22..6e2b400 100644
--- a/become_yukarin/updater/sr_updater.py
+++ b/become_yukarin/updater/sr_updater.py
@@ -33,17 +33,28 @@ class SRUpdater(chainer.training.StandardUpdater):
chainer.report({'loss': loss}, predictor)
return loss
- def _loss_discriminator(self, discriminator, y_in, y_out):
- b, _, w, h = y_in.data.shape
+ def _loss_discriminator(self, discriminator, d_real, d_fake):
+ b, _, w, h = d_real.data.shape
- loss_real = F.sum(F.softplus(-y_in)) / (b * w * h)
+ loss_real = F.sum(F.softplus(-d_real)) / (b * w * h)
chainer.report({'real': loss_real}, discriminator)
- loss_fake = F.sum(F.softplus(y_out)) / (b * w * h)
+ loss_fake = F.sum(F.softplus(d_fake)) / (b * w * h)
chainer.report({'fake': loss_fake}, discriminator)
loss = loss_real + loss_fake
chainer.report({'loss': loss}, discriminator)
+
+ tp = (d_real.data > 0.5).sum()
+ fp = (d_fake.data > 0.5).sum()
+ fn = (d_real.data <= 0.5).sum()
+ tn = (d_fake.data <= 0.5).sum()
+ accuracy = (tp + tn) / (tp + fp + fn + tn)
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ chainer.report({'accuracy': accuracy}, self.discriminator)
+ chainer.report({'precision': precision}, self.discriminator)
+ chainer.report({'recall': recall}, self.discriminator)
return loss
def forward(self, input, target):