summary | refs | log | tree | commit | diff
path: root/become_yukarin/dataset/dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
-rw-r--r-- become_yukarin/dataset/dataset.py | 29
1 files changed, 23 insertions, 6 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index f9db53e..39331c3 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -128,10 +128,20 @@ class ReshapeFeatureProcess(BaseDataProcess):
def __call__(self, data: AcousticFeature, test):
feature = numpy.concatenate([getattr(data, t) for t in self._targets])
- feature = feature[numpy.newaxis]
+ feature = feature.T
return feature
+class ShapeAlignProcess(BaseDataProcess):
+ def __call__(self, data, test):
+ data1, data2 = data['input'], data['target']
+ m = max(data1.shape[1], data2.shape[1])
+ data1 = numpy.pad(data1, ((0, 0), (0, m - data1.shape[1])), mode='constant')
+ data2 = numpy.pad(data2, ((0, 0), (0, m - data2.shape[1])), mode='constant')
+ data['input'], data['target'] = data1, data2
+ return data
+
+
class DataProcessDataset(chainer.dataset.DatasetMixin):
def __init__(self, data: typing.List, data_process: BaseDataProcess):
self._data = data
@@ -144,28 +154,35 @@ class DataProcessDataset(chainer.dataset.DatasetMixin):
return self._data_process(data=self._data[i], test=not chainer.config.train)
-def choose(config: DatasetConfig):
+def create(config: DatasetConfig):
import glob
input_paths = list(sorted([Path(p) for p in glob.glob(config.input_glob)]))
target_paths = list(sorted([Path(p) for p in glob.glob(config.target_glob)]))
assert len(input_paths) == len(target_paths)
+ acoustic_feature_load_process = AcousticFeatureLoadProcess()
+ input_mean = acoustic_feature_load_process(config.input_mean_path, test=True)
+ input_var = acoustic_feature_load_process(config.input_var_path, test=True)
+ target_mean = acoustic_feature_load_process(config.target_mean_path, test=True)
+ target_var = acoustic_feature_load_process(config.target_var_path, test=True)
+
# {input_path, target_path}
data_process = ChainProcess([
SplitProcess(dict(
input=ChainProcess([
LambdaProcess(lambda d, test: d['input_path']),
- AcousticFeatureLoadProcess(),
- AcousticFeatureNormalizeProcess(mean=config.input_mean, var=config.input_var),
+ acoustic_feature_load_process,
+ AcousticFeatureNormalizeProcess(mean=input_mean, var=input_var),
ReshapeFeatureProcess(['mfcc']),
]),
target=ChainProcess([
LambdaProcess(lambda d, test: d['target_path']),
- AcousticFeatureLoadProcess(),
- AcousticFeatureNormalizeProcess(mean=config.target_mean, var=config.target_var),
+ acoustic_feature_load_process,
+ AcousticFeatureNormalizeProcess(mean=target_mean, var=target_var),
ReshapeFeatureProcess(['mfcc']),
]),
)),
+ ShapeAlignProcess(),
])
num_test = config.num_test