diff options
| -rw-r--r-- | become_yukarin/config/sr_config.py | 4 | ||||
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 29 | ||||
| -rw-r--r-- | train.py | 17 | ||||
| -rw-r--r-- | train_sr.py | 15 |
4 files changed, 25 insertions, 40 deletions
diff --git a/become_yukarin/config/sr_config.py b/become_yukarin/config/sr_config.py index 4c62808..266ea04 100644 --- a/become_yukarin/config/sr_config.py +++ b/become_yukarin/config/sr_config.py @@ -12,6 +12,8 @@ class SRDatasetConfig(NamedTuple): param: Param input_glob: Path train_crop_size: int + input_global_noise: float + input_local_noise: float seed: int num_test: int @@ -75,6 +77,8 @@ def create_from_json(s: Union[str, Path]): param=Param(), input_glob=Path(d['dataset']['input_glob']), train_crop_size=d['dataset']['train_crop_size'], + input_global_noise=d['dataset']['input_global_noise'], + input_local_noise=d['dataset']['input_local_noise'], seed=d['dataset']['seed'], num_test=d['dataset']['num_test'], ), diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index ab05568..90dcd4a 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -551,10 +551,22 @@ def create_sr(config: SRDatasetConfig): add_seed(), SplitProcess(dict(input=crop('input'), target=crop('target'))), ])) - data_process_train.append(LambdaProcess(lambda d, test: { - 'input': d['input'][numpy.newaxis], - 'target': d['target'][numpy.newaxis], - })) + + # add noise + data_process_train.append(SplitProcess(dict( + input=ChainProcess([ + LambdaProcess(lambda d, test: d['input']), + AddNoiseProcess(p_global=config.input_global_noise, p_local=config.input_local_noise), + ]), + target=ChainProcess([ + LambdaProcess(lambda d, test: d['target']), + ]), + ))) + + data_process_train.append(LambdaProcess(lambda d, test: { + 'input': d['input'][numpy.newaxis], + 'target': d['target'][numpy.newaxis], + })) data_process_test = copy.deepcopy(data_process_base) if config.train_crop_size is not None: @@ -570,10 +582,11 @@ def create_sr(config: SRDatasetConfig): FirstCropProcess(crop_size=config.train_crop_size, time_axis=0), ]), ))) - data_process_test.append(LambdaProcess(lambda d, test: { - 'input': d['input'][numpy.newaxis], - 'target': d['target'][numpy.newaxis], - })) + + data_process_test.append(LambdaProcess(lambda d, test: { + 'input': d['input'][numpy.newaxis], + 'target': d['target'][numpy.newaxis], + })) input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))])) @@ -81,22 +81,5 @@ trainer.extend(ext, trigger=trigger_snapshot) trainer.extend(extensions.LogReport(trigger=trigger_log)) -if extensions.PlotReport.available(): - trainer.extend(extensions.PlotReport( - y_keys=[ - 'predictor/loss', - 'predictor/l1', - 'test/predictor/loss', - 'train/predictor/loss', - 'discriminator/accuracy', - 'discriminator/fake', - 'discriminator/true', - 'discriminator/grad', - ], - x_key='iteration', - file_name='loss.png', - trigger=trigger_log, - )) - save_args(arguments, arguments.output) trainer.run() diff --git a/train_sr.py b/train_sr.py index 96f11e7..40f311a 100644 --- a/train_sr.py +++ b/train_sr.py @@ -80,20 +80,5 @@ trainer.extend(ext, trigger=trigger_snapshot) trainer.extend(extensions.LogReport(trigger=trigger_log)) trainer.extend(extensions.PrintReport(['predictor/loss'])) -if extensions.PlotReport.available(): - trainer.extend(extensions.PlotReport( - y_keys=[ - 'predictor/loss', - 'predictor/mse', - 'predictor/adversarial', - 'discriminator/accuracy', - 'discriminator/fake', - 'discriminator/real', - ], - x_key='iteration', - file_name='loss.png', - trigger=trigger_log, - )) - save_args(arguments, arguments.output) trainer.run() |
