summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Hiroshiba Kazuyuki <hihokaruta@gmail.com> 2018-02-18 13:57:22 +0900
committer Hiroshiba Kazuyuki <hihokaruta@gmail.com> 2018-02-18 13:57:22 +0900
commit 348834f918b6d644ef9809dde0f7205ba5a364f2 (patch)
tree c1edb96adebaa0498924215d7b8d9f9ae2974a3c
parent 0faa1022b927d53bbb4fb90d6960cf6dc337b58a (diff)
ライブラリの依存をなくした
-rw-r--r-- become_yukarin/dataset/utility.py 74
-rw-r--r-- requirements.txt 1
-rw-r--r-- setup.py 1
3 files changed, 71 insertions, 5 deletions
diff --git a/become_yukarin/dataset/utility.py b/become_yukarin/dataset/utility.py
index ca68acf..7a97967 100644
--- a/become_yukarin/dataset/utility.py
+++ b/become_yukarin/dataset/utility.py
@@ -1,7 +1,75 @@
+import math
+
import fastdtw
-import nnmnkwii.metrics
import numpy
-import scipy.interpolate
+
+_logdb_const = 10.0 / numpy.log(10.0) * numpy.sqrt(2.0)  # (10 / ln 10) * sqrt(2); melcd multiplies its distance by this factor and documents the result as dB
+
+
+# should work on torch and numpy arrays
+def _sqrt(x):  # square root of x, dispatched to numpy.sqrt, math.sqrt or x.sqrt() depending on the kind of input
+ isnumpy = isinstance(x, numpy.ndarray)
+ isscalar = numpy.isscalar(x)  # only consulted when x is not an ndarray
+ return numpy.sqrt(x) if isnumpy else math.sqrt(x) if isscalar else x.sqrt()  # last branch requires x to expose a .sqrt() method (presumably a torch tensor, per the comment above)
+
+
+def _exp(x):  # exponential of x, dispatched to numpy.exp, math.exp or x.exp(); not referenced elsewhere in the visible part of this patch
+ isnumpy = isinstance(x, numpy.ndarray)
+ isscalar = numpy.isscalar(x)  # only consulted when x is not an ndarray
+ return numpy.exp(x) if isnumpy else math.exp(x) if isscalar else x.exp()  # last branch requires x to expose an .exp() method
+
+
+def _sum(x):  # total of all elements of x
+ if isinstance(x, list) or isinstance(x, numpy.ndarray):
+ return numpy.sum(x)  # lists and ndarrays are reduced by numpy
+ return float(x.sum())  # any other input must provide .sum(); NOTE(review): a tuple reaches this line and fails because tuples have no .sum()
+
+
+def melcd(X, Y, lengths=None):  # X, Y: time-aligned feature arrays; lengths: optional per-item valid lengths for padded batches
+ """Mel-cepstrum distortion (MCD).
+
+ The function computes MCD for time-aligned mel-cepstrum sequences.
+
+ Args:
+ X (ndarray): Input mel-cepstrum, shape can be either of
+ (``D``,), (``T x D``) or (``B x T x D``). Both Numpy and torch arrays
+ are supported.
+ Y (ndarray): Target mel-cepstrum, shape can be either of
+ (``D``,), (``T x D``) or (``B x T x D``). Both Numpy and torch arrays
+ are supported.
+ lengths (list): Lengths of padded inputs. This should only be specified
+ if you give mini-batch inputs.
+
+ Returns:
+ float: Mean mel-cepstrum distortion in dB.
+
+ .. note::
+
+ The function doesn't check if inputs are actually mel-cepstrum.
+ """
+ # summing against feature axis, and then take mean against time axis
+ # Eq. (1a)
+ # https://www.cs.cmu.edu/~awb/papers/sltu2008/kominek_black.sltu_2008.pdf
+ if lengths is None:  # no padding information: reduce over the last axis, then average whatever remains
+ z = X - Y
+ r = _sqrt((z * z).sum(-1))  # Euclidean distance along the last axis
+ if not numpy.isscalar(r):  # skip the mean when the distance is already a plain scalar
+ r = r.mean()
+ return _logdb_const * r
+
+ # Case for 1-dim features.
+ if len(X.shape) == 2:
+ # Add feature axis
+ X, Y = X[:, :, None], Y[:, :, None]  # 2-D input gains a trailing axis of size 1
+
+ s = 0.0  # running sum of per-frame distances over the whole batch
+ T = _sum(lengths)  # total number of valid frames across the batch
+ for x, y, length in zip(X, Y, lengths):
+ x, y = x[:length], y[:length]  # keep only the first `length` frames of this item
+ z = x - y
+ s += _sqrt((z * z).sum(-1)).sum()  # per-frame distances of this item, summed
+
+ return _logdb_const * s / T  # average over all valid frames, scaled by the dB constant; NOTE(review): T == 0 is not guarded
class DTWAligner(object):
@@ -43,7 +111,7 @@ class MFCCAligner(DTWAligner):
def __init__(self, x, y, *args, **kwargs) -> None:
x = self._calc_aligner_feature(x)
y = self._calc_aligner_feature(y)
- kwargs.update(dist=nnmnkwii.metrics.melcd)
+ kwargs.update(dist=melcd)  # pass the melcd defined in this module as the `dist` keyword for the parent constructor
super().__init__(x, y, *args, **kwargs)
@classmethod
diff --git a/requirements.txt b/requirements.txt
index 19ca71a..da0242a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,6 +5,5 @@ librosa
pysptk
pyworld
fastdtw
-nnmnkwii
matplotlib
chainerui
diff --git a/setup.py b/setup.py
index dc78f0c..8abd45a 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,6 @@ setup(
'pysptk',
'pyworld',
'fastdtw',
- 'nnmnkwii',
'chainerui',
],
classifiers=[