summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-09 18:59:57 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-09 18:59:57 +0900
commit6b56d407a050c9c24fc0a0f6702bf5e9eee7450f (patch)
tree2e3ea51638586fbb2094eebd6b010e61a52cd5ce
parent12cb80fb45d0f19c5d98ee60cda346ad324d1377 (diff)
discriminator accuracy
-rw-r--r--become_yukarin/dataset/dataset.py4
-rw-r--r--become_yukarin/updater.py11
-rw-r--r--train.py1
3 files changed, 14 insertions, 2 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index d259734..368073d 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -189,7 +189,7 @@ class AcousticFeatureNormalizeProcess(BaseDataProcess):
self._mean = mean
self._var = var
- def __call__(self, data: AcousticFeature, test):
+ def __call__(self, data: AcousticFeature, test=None):
f0 = (data.f0 - self._mean.f0) / numpy.sqrt(self._var.f0)
f0[~data.voiced] = 0
return AcousticFeature(
@@ -206,7 +206,7 @@ class AcousticFeatureDenormalizeProcess(BaseDataProcess):
self._mean = mean
self._var = var
- def __call__(self, data: AcousticFeature, test):
+ def __call__(self, data: AcousticFeature, test=None):
f0 = data.f0 * numpy.sqrt(self._var.f0) + self._mean.f0
f0[~data.voiced] = 0
return AcousticFeature(
diff --git a/become_yukarin/updater.py b/become_yukarin/updater.py
index 02ea5d3..f6444d0 100644
--- a/become_yukarin/updater.py
+++ b/become_yukarin/updater.py
@@ -69,6 +69,17 @@ class Updater(chainer.training.StandardUpdater):
reporter.report({'fake': loss_dis_f}, self.discriminator)
reporter.report({'true': loss_dis_t}, self.discriminator)
+ tp = (d_true.data > 0.5).sum()
+ fp = (d_fake.data > 0.5).sum()
+ fn = (d_true.data <= 0.5).sum()
+ tn = (d_fake.data <= 0.5).sum()
+ accuracy = (tp + tn) / (tp + fp + fn + tn)
+ precision = tp / (tp + fp)
+ recall = tp / (tp + fn)
+ reporter.report({'accuracy': accuracy}, self.discriminator)
+ reporter.report({'precision': precision}, self.discriminator)
+ reporter.report({'recall': recall}, self.discriminator)
+
loss = {'predictor': loss_l1 * self.loss_config.l1}
if self.aligner is not None:
diff --git a/train.py b/train.py
index a80c011..f62d338 100644
--- a/train.py
+++ b/train.py
@@ -88,6 +88,7 @@ if extensions.PlotReport.available():
'predictor/l1',
'test/predictor/loss',
'train/predictor/loss',
+ 'discriminator/accuracy',
'discriminator/fake',
'discriminator/true',
'discriminator/grad',