diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-16 16:02:59 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-16 16:02:59 +0900 |
| commit | 86079f0cea1f79beb7cbbec08f6c19191929207a (patch) | |
| tree | 69ebc560ec11cfba7a0e703f46836b3d7855095a /become_yukarin/dataset/dataset.py | |
| parent | 83608c12e7bb28df1966cbe5b9d86a8e23175044 (diff) | |
add noise
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index ab05568..90dcd4a 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -551,10 +551,22 @@ def create_sr(config: SRDatasetConfig): add_seed(), SplitProcess(dict(input=crop('input'), target=crop('target'))), ])) - data_process_train.append(LambdaProcess(lambda d, test: { - 'input': d['input'][numpy.newaxis], - 'target': d['target'][numpy.newaxis], - })) + + # add noise + data_process_train.append(SplitProcess(dict( + input=ChainProcess([ + LambdaProcess(lambda d, test: d['input']), + AddNoiseProcess(p_global=config.input_global_noise, p_local=config.input_local_noise), + ]), + target=ChainProcess([ + LambdaProcess(lambda d, test: d['target']), + ]), + ))) + + data_process_train.append(LambdaProcess(lambda d, test: { + 'input': d['input'][numpy.newaxis], + 'target': d['target'][numpy.newaxis], + })) data_process_test = copy.deepcopy(data_process_base) if config.train_crop_size is not None: @@ -570,10 +582,11 @@ def create_sr(config: SRDatasetConfig): FirstCropProcess(crop_size=config.train_crop_size, time_axis=0), ]), ))) - data_process_test.append(LambdaProcess(lambda d, test: { - 'input': d['input'][numpy.newaxis], - 'target': d['target'][numpy.newaxis], - })) + + data_process_test.append(LambdaProcess(lambda d, test: { + 'input': d['input'][numpy.newaxis], + 'target': d['target'][numpy.newaxis], + })) input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))])) |
