| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-20 03:06:39 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-20 03:06:39 +0900 |
| commit | 16b4e72fe6728e2e64d4c6357b7c73ac06868c1c | |
| tree | 657f0398b9a237ab46327d08f58a230b9581669b /become_yukarin/dataset/dataset.py | |
| parent | 437a869590c989c184d33990b1d788149d073ee9 | |
aligner
Diffstat (limited to 'become_yukarin/dataset/dataset.py')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 102 |

1 file changed, 83 insertions, 19 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index 83936b1..329226c 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -252,7 +252,41 @@ class ShapeAlignProcess(BaseDataProcess):
         return data
 
 
-class CropProcess(BaseDataProcess):
+class RandomPaddingProcess(BaseDataProcess):
+    def __init__(self, min_size: int, time_axis: int = 1):
+        assert time_axis == 1
+        self._min_size = min_size
+        self._time_axis = time_axis
+
+    def __call__(self, datas: Dict[str, any], test=True):
+        assert not test
+
+        data, seed = datas['data'], datas['seed']
+        random = numpy.random.RandomState(seed)
+
+        if data.shape[self._time_axis] >= self._min_size:
+            return data
+
+        pre = random.randint(self._min_size - data.shape[self._time_axis] + 1)
+        post = self._min_size - pre
+        return numpy.pad(data, ((0, 0), (pre, post)), mode='constant')
+
+
+class LastPaddingProcess(BaseDataProcess):
+    def __init__(self, min_size: int, time_axis: int = 1):
+        assert time_axis == 1
+        self._min_size = min_size
+        self._time_axis = time_axis
+
+    def __call__(self, data: numpy.ndarray, test=None):
+        if data.shape[self._time_axis] >= self._min_size:
+            return data
+
+        pre = self._min_size - data.shape[self._time_axis]
+        return numpy.pad(data, ((0, 0), (pre, 0)), mode='constant')
+
+
+class RandomCropProcess(BaseDataProcess):
     def __init__(self, crop_size: int, time_axis: int = 1):
         self._crop_size = crop_size
         self._time_axis = time_axis
@@ -270,6 +304,15 @@
         return numpy.split(data, [start, start + self._crop_size], axis=self._time_axis)[1]
 
 
+class FirstCropProcess(BaseDataProcess):
+    def __init__(self, crop_size: int, time_axis: int = 1):
+        self._crop_size = crop_size
+        self._time_axis = time_axis
+
+    def __call__(self, data: numpy.ndarray, test=None):
+        return numpy.split(data, [0, self._crop_size], axis=self._time_axis)[1]
+
+
 class AddNoiseProcess(BaseDataProcess):
     def __init__(self, p_global: float = None, p_local: float = None):
         assert p_global is None or 0 <= p_global
@@ -338,24 +381,28 @@ def create(config: DatasetConfig):
     ])
 
     data_process_train = copy.deepcopy(data_process_base)
-    if config.train_crop_size is not None:
-        data_process_train.append(ChainProcess([
-            LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d)),
-            SplitProcess(dict(
-                input=ChainProcess([
-                    LambdaProcess(lambda d, test: dict(data=d['input'], seed=d['seed'])),
-                    CropProcess(crop_size=config.train_crop_size),
-                ]),
-                target=ChainProcess([
-                    LambdaProcess(lambda d, test: dict(data=d['target'], seed=d['seed'])),
-                    CropProcess(crop_size=config.train_crop_size),
-                ]),
-                mask=ChainProcess([
-                    LambdaProcess(lambda d, test: dict(data=d['mask'], seed=d['seed'])),
-                    CropProcess(crop_size=config.train_crop_size),
-                ]),
-            )),
-        ]))
+
+    def add_seed():
+        return LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d))
+
+    def padding(s):
+        return ChainProcess([
+            LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+            RandomPaddingProcess(min_size=config.train_crop_size),
+        ])
+
+    def crop(s):
+        return ChainProcess([
+            LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+            RandomCropProcess(crop_size=config.train_crop_size),
+        ])
+
+    data_process_train.append(ChainProcess([
+        add_seed(),
+        SplitProcess(dict(input=padding('input'), target=padding('target'), mask=padding('mask'))),
+        add_seed(),
+        SplitProcess(dict(input=crop('input'), target=crop('target'), mask=crop('mask'))),
+    ]))
 
     # add noise
     data_process_train.append(SplitProcess(dict(
@@ -373,6 +420,23 @@ def create(config: DatasetConfig):
     )))
 
     data_process_test = data_process_base
+    data_process_test.append(SplitProcess(dict(
+        input=ChainProcess([
+            LambdaProcess(lambda d, test: d['input']),
+            LastPaddingProcess(min_size=config.train_crop_size),
+            FirstCropProcess(crop_size=config.train_crop_size),
+        ]),
+        target=ChainProcess([
+            LambdaProcess(lambda d, test: d['target']),
+            LastPaddingProcess(min_size=config.train_crop_size),
+            FirstCropProcess(crop_size=config.train_crop_size),
+        ]),
+        mask=ChainProcess([
+            LambdaProcess(lambda d, test: d['mask']),
+            LastPaddingProcess(min_size=config.train_crop_size),
+            FirstCropProcess(crop_size=config.train_crop_size),
+        ]),
+    )))
 
     num_test = config.num_test
     pairs = [
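The sketch below (not part of the commit) illustrates how the new alignment steps behave on a short sequence. It assumes a made-up crop_size = 8 in place of config.train_crop_size, a small toy (feature, time) array, and a fixed seed instead of the pipeline's per-example seed draw; it also uses a single random state for both steps, whereas the train pipeline draws a second shared seed before cropping, and the uniform start draw is assumed because RandomCropProcess's body lies outside the visible hunk. BaseDataProcess and the ChainProcess/SplitProcess/LambdaProcess wrappers are omitted; only the pad/crop arithmetic mirrors the diff above.

# Standalone sketch of the train-time and test-time alignment steps.
import numpy

crop_size = 8                            # stands in for config.train_crop_size
data = numpy.arange(3 * 5, dtype=numpy.float32).reshape(3, 5)  # 5 frames, shorter than crop_size

# --- training path: RandomPaddingProcess, then RandomCropProcess ---
seed = 1234                              # the pipeline shares one seed across input/target/mask
rng = numpy.random.RandomState(seed)     # so all three are padded and cropped identically

padded = data
if padded.shape[1] < crop_size:
    # RandomPaddingProcess: adds crop_size frames of zeros in total,
    # `pre` of them before the signal and the rest after it
    pre = rng.randint(crop_size - padded.shape[1] + 1)
    post = crop_size - pre
    padded = numpy.pad(padded, ((0, 0), (pre, post)), mode='constant')

start = rng.randint(padded.shape[1] - crop_size + 1)   # assumed uniform start draw
train_chunk = padded[:, start:start + crop_size]       # equivalent to the numpy.split slice
print(train_chunk.shape)                               # (3, 8)

# --- test path: LastPaddingProcess, then FirstCropProcess ---
padded = data
if padded.shape[1] < crop_size:
    # LastPaddingProcess: deterministic, pads zeros at the front only
    padded = numpy.pad(padded, ((0, 0), (crop_size - padded.shape[1], 0)), mode='constant')
test_chunk = padded[:, :crop_size]                     # FirstCropProcess: keep the first crop_size frames
print(test_chunk.shape)                                # (3, 8)

Both paths yield fixed-length (feature, crop_size) chunks: the training path pads and crops at random positions for augmentation, while the test path pads and crops deterministically so evaluation is repeatable.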
