summaryrefslogtreecommitdiff
path: root/become_yukarin/dataset/dataset.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-16 16:02:59 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-16 16:02:59 +0900
commit86079f0cea1f79beb7cbbec08f6c19191929207a (patch)
tree69ebc560ec11cfba7a0e703f46836b3d7855095a /become_yukarin/dataset/dataset.py
parent83608c12e7bb28df1966cbe5b9d86a8e23175044 (diff)
add noise
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
-rw-r--r--become_yukarin/dataset/dataset.py29
1 files changed, 21 insertions, 8 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index ab05568..90dcd4a 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -551,10 +551,22 @@ def create_sr(config: SRDatasetConfig):
add_seed(),
SplitProcess(dict(input=crop('input'), target=crop('target'))),
]))
- data_process_train.append(LambdaProcess(lambda d, test: {
- 'input': d['input'][numpy.newaxis],
- 'target': d['target'][numpy.newaxis],
- }))
+
+ # add noise
+ data_process_train.append(SplitProcess(dict(
+ input=ChainProcess([
+ LambdaProcess(lambda d, test: d['input']),
+ AddNoiseProcess(p_global=config.input_global_noise, p_local=config.input_local_noise),
+ ]),
+ target=ChainProcess([
+ LambdaProcess(lambda d, test: d['target']),
+ ]),
+ )))
+
+ data_process_train.append(LambdaProcess(lambda d, test: {
+ 'input': d['input'][numpy.newaxis],
+ 'target': d['target'][numpy.newaxis],
+ }))
data_process_test = copy.deepcopy(data_process_base)
if config.train_crop_size is not None:
@@ -570,10 +582,11 @@ def create_sr(config: SRDatasetConfig):
FirstCropProcess(crop_size=config.train_crop_size, time_axis=0),
]),
)))
- data_process_test.append(LambdaProcess(lambda d, test: {
- 'input': d['input'][numpy.newaxis],
- 'target': d['target'][numpy.newaxis],
- }))
+
+ data_process_test.append(LambdaProcess(lambda d, test: {
+ 'input': d['input'][numpy.newaxis],
+ 'target': d['target'][numpy.newaxis],
+ }))
input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))