-rw-r--r--  become_yukarin/config/sr_config.py   4
-rw-r--r--  become_yukarin/dataset/dataset.py   29
-rw-r--r--  train.py                            17
-rw-r--r--  train_sr.py                         15
4 files changed, 25 insertions, 40 deletions
diff --git a/become_yukarin/config/sr_config.py b/become_yukarin/config/sr_config.py
index 4c62808..266ea04 100644
--- a/become_yukarin/config/sr_config.py
+++ b/become_yukarin/config/sr_config.py
@@ -12,6 +12,8 @@ class SRDatasetConfig(NamedTuple):
     param: Param
     input_glob: Path
     train_crop_size: int
+    input_global_noise: float
+    input_local_noise: float
     seed: int
     num_test: int
@@ -75,6 +77,8 @@ def create_from_json(s: Union[str, Path]):
             param=Param(),
             input_glob=Path(d['dataset']['input_glob']),
             train_crop_size=d['dataset']['train_crop_size'],
+            input_global_noise=d['dataset']['input_global_noise'],
+            input_local_noise=d['dataset']['input_local_noise'],
             seed=d['dataset']['seed'],
             num_test=d['dataset']['num_test'],
         ),
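Note that create_from_json indexes d['dataset'] directly, so a config JSON written for the old schema will raise KeyError on the two new keys unless a default is filled in elsewhere. A minimal sketch of a compatible 'dataset' section follows; the glob path, numeric values, and the omitted sections are illustrative assumptions, not taken from the repository.

# Sketch of a config 'dataset' section that satisfies the new schema.
# Path and numeric values here are hypothetical.
import json

config_text = json.dumps({
    'dataset': {
        'input_glob': './dataset/sr/*.npy',   # hypothetical glob
        'train_crop_size': 512,               # hypothetical value
        'input_global_noise': 0.01,           # new key read by create_from_json
        'input_local_noise': 0.01,            # new key read by create_from_json
        'seed': 0,
        'num_test': 5,
    },
    # ...remaining config sections omitted...
})

Existing SR configs only need the two new numeric fields added under 'dataset'.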
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index ab05568..90dcd4a 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -551,10 +551,22 @@ def create_sr(config: SRDatasetConfig):
             add_seed(),
             SplitProcess(dict(input=crop('input'), target=crop('target'))),
         ]))
-    data_process_train.append(LambdaProcess(lambda d, test: {
-        'input': d['input'][numpy.newaxis],
-        'target': d['target'][numpy.newaxis],
-    }))
+
+    # add noise
+    data_process_train.append(SplitProcess(dict(
+        input=ChainProcess([
+            LambdaProcess(lambda d, test: d['input']),
+            AddNoiseProcess(p_global=config.input_global_noise, p_local=config.input_local_noise),
+        ]),
+        target=ChainProcess([
+            LambdaProcess(lambda d, test: d['target']),
+        ]),
+    )))
+
+    data_process_train.append(LambdaProcess(lambda d, test: {
+        'input': d['input'][numpy.newaxis],
+        'target': d['target'][numpy.newaxis],
+    }))
 
     data_process_test = copy.deepcopy(data_process_base)
     if config.train_crop_size is not None:
@@ -570,10 +582,11 @@ def create_sr(config: SRDatasetConfig):
                 FirstCropProcess(crop_size=config.train_crop_size, time_axis=0),
             ]),
         )))
-    data_process_test.append(LambdaProcess(lambda d, test: {
-        'input': d['input'][numpy.newaxis],
-        'target': d['target'][numpy.newaxis],
-    }))
+
+    data_process_test.append(LambdaProcess(lambda d, test: {
+        'input': d['input'][numpy.newaxis],
+        'target': d['target'][numpy.newaxis],
+    }))
 
     input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
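After this change, noise is injected only into the training inputs: AddNoiseProcess sits in the 'input' branch of the train pipeline, while targets and the entire test pipeline stay clean, and the [numpy.newaxis] channel-dimension step is unchanged. The diff does not show AddNoiseProcess itself; below is a minimal sketch of the kind of global-plus-local Gaussian perturbation such a process presumably applies. The actual implementation in become_yukarin/dataset/dataset.py may differ.

from typing import Optional

import numpy


def add_global_local_noise(x: numpy.ndarray, p_global: float, p_local: float,
                           rng: Optional[numpy.random.RandomState] = None) -> numpy.ndarray:
    # Illustration only: one scalar offset shared by the whole sample ("global")
    # plus independent per-element noise ("local"), each scaled by its config value.
    rng = rng if rng is not None else numpy.random.RandomState()
    return x + rng.randn() * p_global + rng.randn(*x.shape) * p_local

Perturbing only the low-resolution input while leaving the target untouched is a common way to make a super-resolution predictor robust to imperfect upstream features.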
diff --git a/train.py b/train.py
index 3e8cced..26490ce 100644
--- a/train.py
+++ b/train.py
@@ -81,22 +81,5 @@ trainer.extend(ext, trigger=trigger_snapshot)
 trainer.extend(extensions.LogReport(trigger=trigger_log))
-if extensions.PlotReport.available():
-    trainer.extend(extensions.PlotReport(
-        y_keys=[
-            'predictor/loss',
-            'predictor/l1',
-            'test/predictor/loss',
-            'train/predictor/loss',
-            'discriminator/accuracy',
-            'discriminator/fake',
-            'discriminator/true',
-            'discriminator/grad',
-        ],
-        x_key='iteration',
-        file_name='loss.png',
-        trigger=trigger_log,
-    ))
-
 save_args(arguments, arguments.output)
 trainer.run()
diff --git a/train_sr.py b/train_sr.py
index 96f11e7..40f311a 100644
--- a/train_sr.py
+++ b/train_sr.py
@@ -80,20 +80,5 @@ trainer.extend(ext, trigger=trigger_snapshot)
 trainer.extend(extensions.LogReport(trigger=trigger_log))
 trainer.extend(extensions.PrintReport(['predictor/loss']))
-if extensions.PlotReport.available():
-    trainer.extend(extensions.PlotReport(
-        y_keys=[
-            'predictor/loss',
-            'predictor/mse',
-            'predictor/adversarial',
-            'discriminator/accuracy',
-            'discriminator/fake',
-            'discriminator/real',
-        ],
-        x_key='iteration',
-        file_name='loss.png',
-        trigger=trigger_log,
-    ))
-
 save_args(arguments, arguments.output)
 trainer.run()
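Dropping PlotReport from both train.py and train_sr.py means no loss.png is rendered during training anymore; the scalar values are still recorded by the LogReport extension that remains. Below is a sketch of recovering the curve afterwards, assuming Chainer's default log file name ('log', a JSON list written to the trainer output directory) and that matplotlib is available; 'output' stands for whatever directory was passed as the training output.

# Plot predictor/loss from the JSON log that LogReport still writes.
import json
from pathlib import Path

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

entries = json.loads(Path('output/log').read_text())
points = [(e['iteration'], e['predictor/loss']) for e in entries if 'predictor/loss' in e]

plt.plot([i for i, _ in points], [v for _, v in points])
plt.xlabel('iteration')
plt.ylabel('predictor/loss')
plt.savefig('loss.png')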