diff options
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 54 |
1 files changed, 49 insertions, 5 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index 6328a1c..a6550b7 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -1,3 +1,4 @@ +import copy import typing from abc import ABCMeta, abstractmethod from collections import defaultdict @@ -41,13 +42,16 @@ class DictKeyReplaceProcess(BaseDataProcess): class ChainProcess(BaseDataProcess): def __init__(self, process: typing.Iterable[BaseDataProcess]): - self._process = process + self._process = list(process) def __call__(self, data, test): for p in self._process: data = p(data, test) return data + def append(self, process: BaseDataProcess): + self._process.append(process) + class SplitProcess(BaseDataProcess): def __init__(self, process: typing.Dict[str, typing.Optional[BaseDataProcess]]): @@ -248,6 +252,24 @@ class ShapeAlignProcess(BaseDataProcess): return data +class CropProcess(BaseDataProcess): + def __init__(self, crop_size: int, time_axis: int = 1): + self._crop_size = crop_size + self._time_axis = time_axis + + def __call__(self, datas: Dict[str, any], test=True): + assert not test + + data, seed = datas['data'], datas['seed'] + random = numpy.random.RandomState(seed) + + len_time = data.shape[self._time_axis] + assert len_time >= self._crop_size + + start = random.randint(len_time - self._crop_size + 1) + return numpy.split(data, [start, start + self._crop_size], axis=self._time_axis)[1] + + class DataProcessDataset(chainer.dataset.DatasetMixin): def __init__(self, data: typing.List, data_process: BaseDataProcess): self._data = data @@ -273,7 +295,7 @@ def create(config: DatasetConfig): target_var = acoustic_feature_load_process(config.target_var_path, test=True) # {input_path, target_path} - data_process = ChainProcess([ + data_process_base = ChainProcess([ SplitProcess(dict( input=ChainProcess([ LambdaProcess(lambda d, test: d['input_path']), @@ -300,6 +322,28 @@ def create(config: DatasetConfig): ShapeAlignProcess(), ]) + data_process_train = copy.deepcopy(data_process_base) + if config.train_crop_size is not None: + data_process_train.append(ChainProcess([ + LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d)), + SplitProcess(dict( + input=ChainProcess([ + LambdaProcess(lambda d, test: dict(data=d['input'], seed=d['seed'])), + CropProcess(crop_size=config.train_crop_size), + ]), + target=ChainProcess([ + LambdaProcess(lambda d, test: dict(data=d['target'], seed=d['seed'])), + CropProcess(crop_size=config.train_crop_size), + ]), + mask=ChainProcess([ + LambdaProcess(lambda d, test: dict(data=d['mask'], seed=d['seed'])), + CropProcess(crop_size=config.train_crop_size), + ]), + )), + ])) + + data_process_test = data_process_base + num_test = config.num_test pairs = [ dict(input_path=input_path, target_path=target_path) @@ -311,7 +355,7 @@ def create(config: DatasetConfig): train_for_evaluate_paths = train_paths[:num_test] return { - 'train': DataProcessDataset(train_paths, data_process), - 'test': DataProcessDataset(test_paths, data_process), - 'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process), + 'train': DataProcessDataset(train_paths, data_process_train), + 'test': DataProcessDataset(test_paths, data_process_test), + 'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process_test), } |
