diff options
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index f9db53e..39331c3 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -128,10 +128,20 @@ class ReshapeFeatureProcess(BaseDataProcess): def __call__(self, data: AcousticFeature, test): feature = numpy.concatenate([getattr(data, t) for t in self._targets]) - feature = feature[numpy.newaxis] + feature = feature.T return feature +class ShapeAlignProcess(BaseDataProcess): + def __call__(self, data, test): + data1, data2 = data['input'], data['target'] + m = max(data1.shape[1], data2.shape[1]) + data1 = numpy.pad(data1, ((0, 0), (0, m - data1.shape[1])), mode='constant') + data2 = numpy.pad(data2, ((0, 0), (0, m - data2.shape[1])), mode='constant') + data['input'], data['target'] = data1, data2 + return data + + class DataProcessDataset(chainer.dataset.DatasetMixin): def __init__(self, data: typing.List, data_process: BaseDataProcess): self._data = data @@ -144,28 +154,35 @@ class DataProcessDataset(chainer.dataset.DatasetMixin): return self._data_process(data=self._data[i], test=not chainer.config.train) -def choose(config: DatasetConfig): +def create(config: DatasetConfig): import glob input_paths = list(sorted([Path(p) for p in glob.glob(config.input_glob)])) target_paths = list(sorted([Path(p) for p in glob.glob(config.target_glob)])) assert len(input_paths) == len(target_paths) + acoustic_feature_load_process = AcousticFeatureLoadProcess() + input_mean = acoustic_feature_load_process(config.input_mean_path, test=True) + input_var = acoustic_feature_load_process(config.input_var_path, test=True) + target_mean = acoustic_feature_load_process(config.target_mean_path, test=True) + target_var = acoustic_feature_load_process(config.target_var_path, test=True) + # {input_path, target_path} data_process = ChainProcess([ SplitProcess(dict( input=ChainProcess([ LambdaProcess(lambda d, test: d['input_path']), - AcousticFeatureLoadProcess(), - AcousticFeatureNormalizeProcess(mean=config.input_mean, var=config.input_var), + acoustic_feature_load_process, + AcousticFeatureNormalizeProcess(mean=input_mean, var=input_var), ReshapeFeatureProcess(['mfcc']), ]), target=ChainProcess([ LambdaProcess(lambda d, test: d['target_path']), - AcousticFeatureLoadProcess(), - AcousticFeatureNormalizeProcess(mean=config.target_mean, var=config.target_var), + acoustic_feature_load_process, + AcousticFeatureNormalizeProcess(mean=target_mean, var=target_var), ReshapeFeatureProcess(['mfcc']), ]), )), + ShapeAlignProcess(), ]) num_test = config.num_test |
