summaryrefslogtreecommitdiff
path: root/train_sr.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <hihokaruta@gmail.com>2018-01-14 07:40:07 +0900
committerHiroshiba Kazuyuki <hihokaruta@gmail.com>2018-01-14 07:40:07 +0900
commit2be3f03adc5695f82c6ab86da780108f786ed014 (patch)
treeae4b95aa3e45706598e66cc00ff5ad9f00ef97a9 /train_sr.py
parentf9185301a22f1632b16dd5266197bb40cb7c302e (diff)
超解像
Diffstat (limited to 'train_sr.py')
-rw-r--r--train_sr.py98
1 file changed, 98 insertions, 0 deletions
diff --git a/train_sr.py b/train_sr.py
new file mode 100644
index 0000000..c714aa0
--- /dev/null
+++ b/train_sr.py
@@ -0,0 +1,98 @@
+import argparse
+from functools import partial
+from pathlib import Path
+
+from chainer import cuda
+from chainer import optimizers
+from chainer import training
+from chainer.dataset import convert
+from chainer.iterators import MultiprocessIterator
+from chainer.training import extensions
+from chainerui.utils import save_args
+
+from become_yukarin.config.sr_config import create_from_json
+from become_yukarin.dataset import create_sr as create_sr_dataset
+from become_yukarin.model.sr_model import create_sr as create_sr_model
+from become_yukarin.updater.sr_updater import SRUpdater
+
# Command-line interface: a JSON file describing the experiment and an
# output directory for checkpoints, logs and plots.
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
parser.add_argument('output', type=Path)
arguments = parser.parse_args()

config = create_from_json(arguments.config_json_path)
# parents=True so a nested output path (e.g. results/sr/run1) is created
# instead of raising FileNotFoundError when intermediate dirs are missing.
arguments.output.mkdir(parents=True, exist_ok=True)
# Keep a copy of the effective config next to the artifacts for reproducibility.
config.save_as_json((arguments.output / 'config.json').absolute())
+
# --- models -----------------------------------------------------------------
# Select the training GPU up front so subsequent allocations land on it.
if config.train.gpu >= 0:
    cuda.get_device_from_id(config.train.gpu).use()

predictor, discriminator = create_sr_model(config.model)
models = dict(
    predictor=predictor,
    discriminator=discriminator,
)

# --- dataset ----------------------------------------------------------------
dataset = create_sr_dataset(config.dataset)
batchsize = config.train.batchsize
# Training iterator shuffles/repeats (defaults); evaluation iterators make a
# single deterministic pass over their split.
train_iter = MultiprocessIterator(dataset['train'], batchsize)
test_iter = MultiprocessIterator(dataset['test'], batchsize, repeat=False, shuffle=False)
train_eval_iter = MultiprocessIterator(dataset['train_eval'], batchsize, repeat=False, shuffle=False)
+
+
# optimizer
def create_optimizer(model, alpha=0.0002, beta1=0.5, beta2=0.999):
    """Create an Adam optimizer set up for ``model``.

    The defaults (alpha=0.0002, beta1=0.5) are the Adam settings commonly
    used for GAN training; they are now parameters so callers can override
    them without changing existing behavior.

    :param model: chainer Link whose parameters the optimizer will update.
    :param alpha: Adam learning rate.
    :param beta1: exponential decay rate of the first moment estimate.
    :param beta2: exponential decay rate of the second moment estimate.
    :return: the configured ``optimizers.Adam`` instance.
    """
    optimizer = optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2)
    optimizer.setup(model)
    return optimizer


# One independent optimizer per model (predictor and discriminator).
opts = {key: create_optimizer(model) for key, model in models.items()}
+
# updater
# Batch converter: concatenate examples, zero-padding ragged entries.
converter = partial(convert.concat_examples, padding=0)

# All arguments are passed by keyword, so this ordering is equivalent to any other.
updater = SRUpdater(
    iterator=train_iter,
    optimizer=opts,
    converter=converter,
    device=config.train.gpu,
    loss_config=config.loss,
    predictor=predictor,
    discriminator=discriminator,
)
+
# trainer
trigger_log = (config.train.log_iteration, 'iteration')
trigger_snapshot = (config.train.snapshot_iteration, 'iteration')

# NOTE(review): no stop_trigger is supplied, so training runs until interrupted.
trainer = training.Trainer(updater, out=arguments.output)

# Periodic evaluation on the test split and on a held-out slice of the
# training data; registration order ('test' then 'train') is preserved.
for name, iterator in (('test', test_iter), ('train', train_eval_iter)):
    evaluator = extensions.Evaluator(
        iterator, models, converter,
        device=config.train.gpu, eval_func=updater.forward,
    )
    trainer.extend(evaluator, name=name, trigger=trigger_log)

# Dump the computational graph rooted at the predictor loss once.
trainer.extend(extensions.dump_graph('predictor/loss'))

# Periodically snapshot only the predictor weights.
trainer.extend(
    extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz'),
    trigger=trigger_snapshot,
)

trainer.extend(extensions.LogReport(trigger=trigger_log))

if extensions.PlotReport.available():
    plot_keys = [
        'predictor/loss',
        'predictor/mse',
        'predictor/adversarial',
        'discriminator/accuracy',
        'discriminator/fake',
        'discriminator/real',
    ]
    trainer.extend(extensions.PlotReport(
        y_keys=plot_keys,
        x_key='iteration',
        file_name='loss.png',
        trigger=trigger_log,
    ))

# Record CLI arguments for ChainerUI, then start the training loop.
save_args(arguments, arguments.output)
trainer.run()