summaryrefslogtreecommitdiff
path: root/become_yukarin/dataset
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-20 03:06:39 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-20 03:06:39 +0900
commit16b4e72fe6728e2e64d4c6357b7c73ac06868c1c (patch)
tree657f0398b9a237ab46327d08f58a230b9581669b /become_yukarin/dataset
parent437a869590c989c184d33990b1d788149d073ee9 (diff)
aligner
Diffstat (limited to 'become_yukarin/dataset')
-rw-r--r--become_yukarin/dataset/dataset.py102
1 files changed, 83 insertions, 19 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index 83936b1..329226c 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -252,7 +252,41 @@ class ShapeAlignProcess(BaseDataProcess):
return data
-class CropProcess(BaseDataProcess):
+class RandomPaddingProcess(BaseDataProcess):
+ def __init__(self, min_size: int, time_axis: int = 1):
+ assert time_axis == 1
+ self._min_size = min_size
+ self._time_axis = time_axis
+
+ def __call__(self, datas: Dict[str, any], test=True):
+ assert not test
+
+ data, seed = datas['data'], datas['seed']
+ random = numpy.random.RandomState(seed)
+
+ if data.shape[self._time_axis] >= self._min_size:
+ return data
+
+ pre = random.randint(self._min_size - data.shape[self._time_axis] + 1)
+ post = self._min_size - pre
+ return numpy.pad(data, ((0, 0), (pre, post)), mode='constant')
+
+
+class LastPaddingProcess(BaseDataProcess):
+ def __init__(self, min_size: int, time_axis: int = 1):
+ assert time_axis == 1
+ self._min_size = min_size
+ self._time_axis = time_axis
+
+ def __call__(self, data: numpy.ndarray, test=None):
+ if data.shape[self._time_axis] >= self._min_size:
+ return data
+
+ pre = self._min_size - data.shape[self._time_axis]
+ return numpy.pad(data, ((0, 0), (pre, 0)), mode='constant')
+
+
+class RandomCropProcess(BaseDataProcess):
def __init__(self, crop_size: int, time_axis: int = 1):
self._crop_size = crop_size
self._time_axis = time_axis
@@ -270,6 +304,15 @@ class CropProcess(BaseDataProcess):
return numpy.split(data, [start, start + self._crop_size], axis=self._time_axis)[1]
+class FirstCropProcess(BaseDataProcess):
+ def __init__(self, crop_size: int, time_axis: int = 1):
+ self._crop_size = crop_size
+ self._time_axis = time_axis
+
+ def __call__(self, data: numpy.ndarray, test=None):
+ return numpy.split(data, [0, self._crop_size], axis=self._time_axis)[1]
+
+
class AddNoiseProcess(BaseDataProcess):
def __init__(self, p_global: float = None, p_local: float = None):
assert p_global is None or 0 <= p_global
@@ -338,24 +381,28 @@ def create(config: DatasetConfig):
])
data_process_train = copy.deepcopy(data_process_base)
- if config.train_crop_size is not None:
- data_process_train.append(ChainProcess([
- LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d)),
- SplitProcess(dict(
- input=ChainProcess([
- LambdaProcess(lambda d, test: dict(data=d['input'], seed=d['seed'])),
- CropProcess(crop_size=config.train_crop_size),
- ]),
- target=ChainProcess([
- LambdaProcess(lambda d, test: dict(data=d['target'], seed=d['seed'])),
- CropProcess(crop_size=config.train_crop_size),
- ]),
- mask=ChainProcess([
- LambdaProcess(lambda d, test: dict(data=d['mask'], seed=d['seed'])),
- CropProcess(crop_size=config.train_crop_size),
- ]),
- )),
- ]))
+
+ def add_seed():
+ return LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d))
+
+ def padding(s):
+ return ChainProcess([
+ LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+ RandomPaddingProcess(min_size=config.train_crop_size),
+ ])
+
+ def crop(s):
+ return ChainProcess([
+ LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+ RandomCropProcess(crop_size=config.train_crop_size),
+ ])
+
+ data_process_train.append(ChainProcess([
+ add_seed(),
+ SplitProcess(dict(input=padding('input'), target=padding('target'), mask=padding('mask'))),
+ add_seed(),
+ SplitProcess(dict(input=crop('input'), target=crop('target'), mask=crop('mask'))),
+ ]))
# add noise
data_process_train.append(SplitProcess(dict(
@@ -373,6 +420,23 @@ def create(config: DatasetConfig):
)))
data_process_test = data_process_base
+ data_process_test.append(SplitProcess(dict(
+ input=ChainProcess([
+ LambdaProcess(lambda d, test: d['input']),
+ LastPaddingProcess(min_size=config.train_crop_size),
+ FirstCropProcess(crop_size=config.train_crop_size),
+ ]),
+ target=ChainProcess([
+ LambdaProcess(lambda d, test: d['target']),
+ LastPaddingProcess(min_size=config.train_crop_size),
+ FirstCropProcess(crop_size=config.train_crop_size),
+ ]),
+ mask=ChainProcess([
+ LambdaProcess(lambda d, test: d['mask']),
+ LastPaddingProcess(min_size=config.train_crop_size),
+ FirstCropProcess(crop_size=config.train_crop_size),
+ ]),
+ )))
num_test = config.num_test
pairs = [