Diffstat (limited to 'become_yukarin/dataset/dataset.py')
-rw-r--r--  become_yukarin/dataset/dataset.py | 125
1 file changed, 118 insertions(+), 7 deletions(-)
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index 368073d..b0f9807 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -1,4 +1,5 @@
import copy
+import glob
import typing
from abc import ABCMeta, abstractmethod
from collections import defaultdict
@@ -13,8 +14,10 @@ import numpy
import pysptk
import pyworld
-from ..config import DatasetConfig
+from ..config.config import DatasetConfig
+from ..config.sr_config import SRDatasetConfig
from ..data_struct import AcousticFeature
+from ..data_struct import LowHighSpectrogramFeature
from ..data_struct import Wave
@@ -112,6 +115,35 @@ class AcousticFeatureProcess(BaseDataProcess):
        return feature


+class LowHighSpectrogramFeatureProcess(BaseDataProcess):
+    def __init__(self, frame_period, order, alpha, dtype=numpy.float32):
+        self._acoustic_feature_process = AcousticFeatureProcess(
+            frame_period=frame_period,
+            order=order,
+            alpha=alpha,
+        )
+        self._dtype = dtype
+        self._alpha = alpha
+
+    def __call__(self, data: Wave, test):
+        acoustic_feature = self._acoustic_feature_process(data, test=True).astype_only_float(self._dtype)
+        high_spectrogram = acoustic_feature.spectrogram
+
+        fftlen = pyworld.get_cheaptrick_fft_size(data.sampling_rate)
+        low_spectrogram = pysptk.mc2sp(
+            acoustic_feature.mfcc,
+            alpha=self._alpha,
+            fftlen=fftlen,
+        )
+
+        feature = LowHighSpectrogramFeature(
+            low=low_spectrogram,
+            high=high_spectrogram,
+        )
+        feature.validate()
+        return feature
+
+
class AcousticFeatureLoadProcess(BaseDataProcess):
    def __init__(self, validate=False):
        self._validate = validate
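
For reference, the new LowHighSpectrogramFeatureProcess can be exercised roughly as below. This is a minimal sketch: the Wave constructor keywords and the parameter values are assumptions for illustration and are not part of this diff.

import librosa
import numpy

from become_yukarin.data_struct import Wave
from become_yukarin.dataset.dataset import LowHighSpectrogramFeatureProcess

# Load any wav file; 'sample.wav' and sr=24000 are placeholder choices.
wave, sampling_rate = librosa.load('sample.wav', sr=24000)
data = Wave(wave=wave.astype(numpy.float64), sampling_rate=sampling_rate)  # assumed constructor keywords

# frame_period / order / alpha are illustrative values, not taken from the project config.
process = LowHighSpectrogramFeatureProcess(frame_period=5.0, order=8, alpha=0.466)
feature = process(data, test=True)
print(feature.low.shape, feature.high.shape)  # both (frames, fft_size // 2 + 1)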
@@ -130,6 +162,21 @@ class AcousticFeatureLoadProcess(BaseDataProcess):
        return feature


+class LowHighSpectrogramFeatureLoadProcess(BaseDataProcess):
+    def __init__(self, validate=False):
+        self._validate = validate
+
+    def __call__(self, path: Path, test=None):
+        d = numpy.load(path.expanduser()).item()  # type: dict
+        feature = LowHighSpectrogramFeature(
+            low=d['low'],
+            high=d['high'],
+        )
+        if self._validate:
+            feature.validate()
+        return feature
+
+
class AcousticFeatureSaveProcess(BaseDataProcess):
    def __init__(self, validate=False, ignore: List[str] = None):
        self._validate = validate
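
The matching load process reads a .npy file produced by numpy.save of a dict with 'low' and 'high' arrays. A round-trip sketch follows; the file name and array shapes are hypothetical, and note that numpy >= 1.16.3 additionally requires allow_pickle=True inside numpy.load for this pattern.

from pathlib import Path

import numpy

from become_yukarin.dataset.dataset import LowHighSpectrogramFeatureLoadProcess

low = numpy.random.rand(100, 513).astype(numpy.float32)   # hypothetical low-band spectrogram
high = numpy.random.rand(100, 513).astype(numpy.float32)  # hypothetical high-band spectrogram
numpy.save('example_spectrogram.npy', dict(low=low, high=high))

load = LowHighSpectrogramFeatureLoadProcess(validate=False)
feature = load(Path('example_spectrogram.npy'))
assert feature.low.shape == (100, 513) and feature.high.shape == (100, 513)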
@@ -353,11 +400,6 @@ class DataProcessDataset(chainer.dataset.DatasetMixin):
def create(config: DatasetConfig):
-    import glob
-    input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
-    target_paths = list(sorted([Path(p) for p in glob.glob(str(config.target_glob))]))
-    assert len(input_paths) == len(target_paths)
-
    acoustic_feature_load_process = AcousticFeatureLoadProcess()
    input_mean = acoustic_feature_load_process(config.input_mean_path, test=True)
    input_var = acoustic_feature_load_process(config.input_var_path, test=True)
@@ -433,7 +475,7 @@ def create(config: DatasetConfig):
            ]),
        )))

-    data_process_test = data_process_base
+    data_process_test = copy.deepcopy(data_process_base)
    if config.train_crop_size is not None:
        data_process_test.append(SplitProcess(dict(
            input=ChainProcess([
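
The switch to copy.deepcopy matters because the subsequent append calls would otherwise mutate the shared base pipeline that the training pipeline also builds on. A minimal illustration with a plain list standing in for the pipeline's internal step list (an assumption about ChainProcess, used only to show the aliasing effect):

import copy

base = ['load', 'normalize']        # stand-in for the shared base pipeline

aliased = base                      # old behaviour: same object
aliased.append('crop')              # also mutates base
assert base == ['load', 'normalize', 'crop']

base = ['load', 'normalize']
independent = copy.deepcopy(base)   # new behaviour: independent copy
independent.append('crop')          # base stays untouched
assert base == ['load', 'normalize']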
@@ -453,6 +495,10 @@ def create(config: DatasetConfig):
            ]),
        )))

+    input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
+    target_paths = list(sorted([Path(p) for p in glob.glob(str(config.target_glob))]))
+    assert len(input_paths) == len(target_paths)
+
    num_test = config.num_test
    pairs = [
        dict(input_path=input_path, target_path=target_path)
@@ -468,3 +514,68 @@ def create(config: DatasetConfig):
        'test': DataProcessDataset(test_paths, data_process_test),
        'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process_test),
    }
+
+
+def create_sr(config: SRDatasetConfig):
+    data_process_base = ChainProcess([
+        LowHighSpectrogramFeatureLoadProcess(validate=True),
+        SplitProcess(dict(
+            input=LambdaProcess(lambda d, test: numpy.log(d.low)),
+            target=LambdaProcess(lambda d, test: numpy.log(d.high)),
+        )),
+    ])
+
+    data_process_train = copy.deepcopy(data_process_base)
+
+    # cropping
+    if config.train_crop_size is not None:
+        def add_seed():
+            return LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d))
+
+        def padding(s):
+            return ChainProcess([
+                LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+                RandomPaddingProcess(min_size=config.train_crop_size),
+            ])
+
+        def crop(s):
+            return ChainProcess([
+                LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+                RandomCropProcess(crop_size=config.train_crop_size),
+            ])
+
+        data_process_train.append(ChainProcess([
+            add_seed(),
+            SplitProcess(dict(input=padding('input'), target=padding('target'))),
+            add_seed(),
+            SplitProcess(dict(input=crop('input'), target=crop('target'))),
+        ]))
+
+    data_process_test = copy.deepcopy(data_process_base)
+    if config.train_crop_size is not None:
+        data_process_test.append(SplitProcess(dict(
+            input=ChainProcess([
+                LambdaProcess(lambda d, test: d['input']),
+                LastPaddingProcess(min_size=config.train_crop_size),
+                FirstCropProcess(crop_size=config.train_crop_size),
+            ]),
+            target=ChainProcess([
+                LambdaProcess(lambda d, test: d['target']),
+                LastPaddingProcess(min_size=config.train_crop_size),
+                FirstCropProcess(crop_size=config.train_crop_size),
+            ]),
+        )))
+
+    input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
+
+    num_test = config.num_test
+    numpy.random.RandomState(config.seed).shuffle(input_paths)
+    train_paths = input_paths[num_test:]
+    test_paths = input_paths[:num_test]
+    train_for_evaluate_paths = train_paths[:num_test]
+
+    return {
+        'train': DataProcessDataset(train_paths, data_process_train),
+        'test': DataProcessDataset(test_paths, data_process_test),
+        'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process_test),
+    }