diff options
Diffstat (limited to 'become_yukarin/updater/sr_updater.py')
| -rw-r--r-- | become_yukarin/updater/sr_updater.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/become_yukarin/updater/sr_updater.py b/become_yukarin/updater/sr_updater.py index a6b1d22..6e2b400 100644 --- a/become_yukarin/updater/sr_updater.py +++ b/become_yukarin/updater/sr_updater.py @@ -33,17 +33,28 @@ class SRUpdater(chainer.training.StandardUpdater): chainer.report({'loss': loss}, predictor) return loss - def _loss_discriminator(self, discriminator, y_in, y_out): - b, _, w, h = y_in.data.shape + def _loss_discriminator(self, discriminator, d_real, d_fake): + b, _, w, h = d_real.data.shape - loss_real = F.sum(F.softplus(-y_in)) / (b * w * h) + loss_real = F.sum(F.softplus(-d_real)) / (b * w * h) chainer.report({'real': loss_real}, discriminator) - loss_fake = F.sum(F.softplus(y_out)) / (b * w * h) + loss_fake = F.sum(F.softplus(d_fake)) / (b * w * h) chainer.report({'fake': loss_fake}, discriminator) loss = loss_real + loss_fake chainer.report({'loss': loss}, discriminator) + + tp = (d_real.data > 0.5).sum() + fp = (d_fake.data > 0.5).sum() + fn = (d_real.data <= 0.5).sum() + tn = (d_fake.data <= 0.5).sum() + accuracy = (tp + tn) / (tp + fp + fn + tn) + precision = tp / (tp + fp) + recall = tp / (tp + fn) + chainer.report({'accuracy': accuracy}, self.discriminator) + chainer.report({'precision': precision}, self.discriminator) + chainer.report({'recall': recall}, self.discriminator) return loss def forward(self, input, target): |
