diff options
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index b0f9807..38cf749 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -313,7 +313,6 @@ class ShapeAlignProcess(BaseDataProcess): class RandomPaddingProcess(BaseDataProcess): def __init__(self, min_size: int, time_axis: int = 1): - assert time_axis == 1 self._min_size = min_size self._time_axis = time_axis @@ -328,7 +327,9 @@ class RandomPaddingProcess(BaseDataProcess): pre = random.randint(self._min_size - data.shape[self._time_axis] + 1) post = self._min_size - pre - return numpy.pad(data, ((0, 0), (pre, post)), mode='constant') + pad = [(0, 0)] * data.ndim + pad[self._time_axis] = (pre, post) + return numpy.pad(data, pad, mode='constant') class LastPaddingProcess(BaseDataProcess): @@ -520,8 +521,8 @@ def create_sr(config: SRDatasetConfig): data_process_base = ChainProcess([ LowHighSpectrogramFeatureLoadProcess(validate=True), SplitProcess(dict( - input=LambdaProcess(lambda d, test: numpy.log(d.low)), - target=LambdaProcess(lambda d, test: numpy.log(d.high)), + input=LambdaProcess(lambda d, test: numpy.log(d.low[:, :-1])), + target=LambdaProcess(lambda d, test: numpy.log(d.high[:, :-1])), )), ]) @@ -535,13 +536,13 @@ def create_sr(config: SRDatasetConfig): def padding(s): return ChainProcess([ LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])), - RandomPaddingProcess(min_size=config.train_crop_size), + RandomPaddingProcess(min_size=config.train_crop_size, time_axis=0), ]) def crop(s): return ChainProcess([ LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])), - RandomCropProcess(crop_size=config.train_crop_size), + RandomCropProcess(crop_size=config.train_crop_size, time_axis=0), ]) data_process_train.append(ChainProcess([ @@ -550,6 +551,10 @@ def create_sr(config: SRDatasetConfig): add_seed(), SplitProcess(dict(input=crop('input'), target=crop('target'))), ])) + data_process_train.append(LambdaProcess(lambda d, test: { + 'input': d['input'][numpy.newaxis], + 'target': d['target'][numpy.newaxis], + })) data_process_test = copy.deepcopy(data_process_base) if config.train_crop_size is not None: @@ -557,14 +562,18 @@ def create_sr(config: SRDatasetConfig): input=ChainProcess([ LambdaProcess(lambda d, test: d['input']), LastPaddingProcess(min_size=config.train_crop_size), - FirstCropProcess(crop_size=config.train_crop_size), + FirstCropProcess(crop_size=config.train_crop_size, time_axis=0), ]), target=ChainProcess([ LambdaProcess(lambda d, test: d['target']), LastPaddingProcess(min_size=config.train_crop_size), - FirstCropProcess(crop_size=config.train_crop_size), + FirstCropProcess(crop_size=config.train_crop_size, time_axis=0), ]), ))) + data_process_test.append(LambdaProcess(lambda d, test: { + 'input': d['input'][numpy.newaxis], + 'target': d['target'][numpy.newaxis], + })) input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))])) |
