diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-15 04:17:10 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-15 04:41:33 +0900 |
| commit | 7bfc3321e356f24f49c790b578917e8db22bd30d (patch) | |
| tree | 65e8070b34d22e8b6211bb41e7dd448eb39dccd1 /become_yukarin/updater | |
| parent | 2be3f03adc5695f82c6ab86da780108f786ed014 (diff) | |
超解像学習を可能に
Diffstat (limited to 'become_yukarin/updater')
| -rw-r--r-- | become_yukarin/updater/sr_updater.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/become_yukarin/updater/sr_updater.py b/become_yukarin/updater/sr_updater.py index a6b1d22..6e2b400 100644 --- a/become_yukarin/updater/sr_updater.py +++ b/become_yukarin/updater/sr_updater.py @@ -33,17 +33,28 @@ class SRUpdater(chainer.training.StandardUpdater): chainer.report({'loss': loss}, predictor) return loss - def _loss_discriminator(self, discriminator, y_in, y_out): - b, _, w, h = y_in.data.shape + def _loss_discriminator(self, discriminator, d_real, d_fake): + b, _, w, h = d_real.data.shape - loss_real = F.sum(F.softplus(-y_in)) / (b * w * h) + loss_real = F.sum(F.softplus(-d_real)) / (b * w * h) chainer.report({'real': loss_real}, discriminator) - loss_fake = F.sum(F.softplus(y_out)) / (b * w * h) + loss_fake = F.sum(F.softplus(d_fake)) / (b * w * h) chainer.report({'fake': loss_fake}, discriminator) loss = loss_real + loss_fake chainer.report({'loss': loss}, discriminator) + + tp = (d_real.data > 0.5).sum() + fp = (d_fake.data > 0.5).sum() + fn = (d_real.data <= 0.5).sum() + tn = (d_fake.data <= 0.5).sum() + accuracy = (tp + tn) / (tp + fp + fn + tn) + precision = tp / (tp + fp) + recall = tp / (tp + fn) + chainer.report({'accuracy': accuracy}, self.discriminator) + chainer.report({'precision': precision}, self.discriminator) + chainer.report({'recall': recall}, self.discriminator) return loss def forward(self, input, target): |
