diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-09 18:59:57 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-09 18:59:57 +0900 |
| commit | 6b56d407a050c9c24fc0a0f6702bf5e9eee7450f (patch) | |
| tree | 2e3ea51638586fbb2094eebd6b010e61a52cd5ce | |
| parent | 12cb80fb45d0f19c5d98ee60cda346ad324d1377 (diff) | |
discriminator accuracy
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 4 | ||||
| -rw-r--r-- | become_yukarin/updater.py | 11 | ||||
| -rw-r--r-- | train.py | 1 |
3 files changed, 14 insertions, 2 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index d259734..368073d 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -189,7 +189,7 @@ class AcousticFeatureNormalizeProcess(BaseDataProcess): self._mean = mean self._var = var - def __call__(self, data: AcousticFeature, test): + def __call__(self, data: AcousticFeature, test=None): f0 = (data.f0 - self._mean.f0) / numpy.sqrt(self._var.f0) f0[~data.voiced] = 0 return AcousticFeature( @@ -206,7 +206,7 @@ class AcousticFeatureDenormalizeProcess(BaseDataProcess): self._mean = mean self._var = var - def __call__(self, data: AcousticFeature, test): + def __call__(self, data: AcousticFeature, test=None): f0 = data.f0 * numpy.sqrt(self._var.f0) + self._mean.f0 f0[~data.voiced] = 0 return AcousticFeature( diff --git a/become_yukarin/updater.py b/become_yukarin/updater.py index 02ea5d3..f6444d0 100644 --- a/become_yukarin/updater.py +++ b/become_yukarin/updater.py @@ -69,6 +69,17 @@ class Updater(chainer.training.StandardUpdater): reporter.report({'fake': loss_dis_f}, self.discriminator) reporter.report({'true': loss_dis_t}, self.discriminator) + tp = (d_true.data > 0.5).sum() + fp = (d_fake.data > 0.5).sum() + fn = (d_true.data <= 0.5).sum() + tn = (d_fake.data <= 0.5).sum() + accuracy = (tp + tn) / (tp + fp + fn + tn) + precision = tp / (tp + fp) + recall = tp / (tp + fn) + reporter.report({'accuracy': accuracy}, self.discriminator) + reporter.report({'precision': precision}, self.discriminator) + reporter.report({'recall': recall}, self.discriminator) + loss = {'predictor': loss_l1 * self.loss_config.l1} if self.aligner is not None: @@ -88,6 +88,7 @@ if extensions.PlotReport.available(): 'predictor/l1', 'test/predictor/loss', 'train/predictor/loss', + 'discriminator/accuracy', 'discriminator/fake', 'discriminator/true', 'discriminator/grad', |
