Diffstat (limited to 'become_yukarin/dataset/dataset.py')
-rw-r--r--  become_yukarin/dataset/dataset.py  54
1 file changed, 49 insertions(+), 5 deletions(-)
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index 6328a1c..a6550b7 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -1,3 +1,4 @@
+import copy
 import typing
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
@@ -41,13 +42,16 @@ class DictKeyReplaceProcess(BaseDataProcess):
 
 class ChainProcess(BaseDataProcess):
     def __init__(self, process: typing.Iterable[BaseDataProcess]):
-        self._process = process
+        self._process = list(process)
 
     def __call__(self, data, test):
         for p in self._process:
             data = p(data, test)
         return data
 
+    def append(self, process: BaseDataProcess):
+        self._process.append(process)
+
 
 class SplitProcess(BaseDataProcess):
     def __init__(self, process: typing.Dict[str, typing.Optional[BaseDataProcess]]):
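
Note: storing the processes via list(process) makes the new append method safe even when a generator is passed, and lets a pipeline grow after construction. A minimal usage sketch (AddOne and Double are hypothetical BaseDataProcess subclasses, not part of this module):

class AddOne(BaseDataProcess):
    def __call__(self, data, test):
        return data + 1

class Double(BaseDataProcess):
    def __call__(self, data, test):
        return data * 2

chain = ChainProcess([AddOne()])
chain.append(Double())             # extend the pipeline after construction
assert chain(3, test=False) == 8   # (3 + 1) * 2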
@@ -248,6 +252,24 @@ class ShapeAlignProcess(BaseDataProcess):
         return data
 
 
+class CropProcess(BaseDataProcess):
+    def __init__(self, crop_size: int, time_axis: int = 1):
+        self._crop_size = crop_size
+        self._time_axis = time_axis
+
+    def __call__(self, datas: typing.Dict[str, typing.Any], test=True):
+        assert not test
+
+        data, seed = datas['data'], datas['seed']
+        random = numpy.random.RandomState(seed)
+
+        len_time = data.shape[self._time_axis]
+        assert len_time >= self._crop_size
+
+        start = random.randint(len_time - self._crop_size + 1)
+        return numpy.split(data, [start, start + self._crop_size], axis=self._time_axis)[1]
+
+
 class DataProcessDataset(chainer.dataset.DatasetMixin):
     def __init__(self, data: typing.List, data_process: BaseDataProcess):
         self._data = data
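
CropProcess draws its start offset from a RandomState seeded by the caller, so the same seed always selects the same window; this is what allows separate arrays to be cropped consistently. A minimal sketch, assuming a (feature, time) array and the default time_axis=1:

import numpy

crop = CropProcess(crop_size=4)
feat = numpy.arange(20).reshape(2, 10)          # 2 features, 10 time frames
a = crop(dict(data=feat, seed=7), test=False)
b = crop(dict(data=feat, seed=7), test=False)
assert a.shape == (2, 4)                        # crop_size frames kept
assert (a == b).all()                           # same seed, same window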
@@ -273,7 +295,7 @@ def create(config: DatasetConfig):
     target_var = acoustic_feature_load_process(config.target_var_path, test=True)
 
     # {input_path, target_path}
-    data_process = ChainProcess([
+    data_process_base = ChainProcess([
         SplitProcess(dict(
             input=ChainProcess([
                 LambdaProcess(lambda d, test: d['input_path']),
@@ -300,6 +322,28 @@ def create(config: DatasetConfig):
         ShapeAlignProcess(),
     ])
 
+    data_process_train = copy.deepcopy(data_process_base)
+    if config.train_crop_size is not None:
+        data_process_train.append(ChainProcess([
+            LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d)),
+            SplitProcess(dict(
+                input=ChainProcess([
+                    LambdaProcess(lambda d, test: dict(data=d['input'], seed=d['seed'])),
+                    CropProcess(crop_size=config.train_crop_size),
+                ]),
+                target=ChainProcess([
+                    LambdaProcess(lambda d, test: dict(data=d['target'], seed=d['seed'])),
+                    CropProcess(crop_size=config.train_crop_size),
+                ]),
+                mask=ChainProcess([
+                    LambdaProcess(lambda d, test: dict(data=d['mask'], seed=d['seed'])),
+                    CropProcess(crop_size=config.train_crop_size),
+                ]),
+            )),
+        ]))
+
+    data_process_test = data_process_base
+
     num_test = config.num_test
     pairs = [
         dict(input_path=input_path, target_path=target_path)
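
The first LambdaProcess draws one seed per example and fans it out to the input, target, and mask branches, so all three arrays are cropped at the same time offset. A hedged sketch of that invariant, with the crop stage in isolation (crop_size=4 and the two-branch dict are illustrative; it assumes SplitProcess hands the whole example dict to each branch, as the base pipeline above does):

crop_stage = ChainProcess([
    LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d)),
    SplitProcess(dict(
        input=ChainProcess([
            LambdaProcess(lambda d, test: dict(data=d['input'], seed=d['seed'])),
            CropProcess(crop_size=4),
        ]),
        target=ChainProcess([
            LambdaProcess(lambda d, test: dict(data=d['target'], seed=d['seed'])),
            CropProcess(crop_size=4),
        ]),
    )),
])
t = numpy.arange(100)[numpy.newaxis]             # shape (1, 100), time axis 1
out = crop_stage(dict(input=t, target=t), test=False)
assert (out['input'] == out['target']).all()     # both branches share one window

copy.deepcopy keeps the appended crop stage out of data_process_test, so evaluation still sees full-length features.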
@@ -311,7 +355,7 @@ def create(config: DatasetConfig):
     train_for_evaluate_paths = train_paths[:num_test]
 
     return {
-        'train': DataProcessDataset(train_paths, data_process),
-        'test': DataProcessDataset(test_paths, data_process),
-        'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process),
+        'train': DataProcessDataset(train_paths, data_process_train),
+        'test': DataProcessDataset(test_paths, data_process_test),
+        'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process_test),
    }
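
Each returned dataset is a chainer DatasetMixin, so it plugs into the standard iterators; only 'train' applies the random crop. A hedged usage sketch (batch_size is arbitrary):

import chainer

datasets = create(config)   # config: a DatasetConfig built elsewhere in the project
train_iter = chainer.iterators.SerialIterator(datasets['train'], batch_size=8)
test_iter = chainer.iterators.SerialIterator(
    datasets['test'], batch_size=8, repeat=False, shuffle=False)
batch = train_iter.next()   # examples produced by data_process_train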