author     Hiroshiba Kazuyuki <hihokaruta@gmail.com>    2017-11-07 10:20:04 +0900
committer  Hiroshiba Kazuyuki <hihokaruta@gmail.com>    2017-11-07 10:29:28 +0900
commit     6119849270c2aed117627d7d2b060f37d1c25de4 (patch)
tree       178e8df7a0c3d33f9de776b85ee4ff545f2ecdc3 /train.py
parent     8e637c41a262373786b94d40a8f3559caf5cd44c (diff)
can train
Diffstat (limited to 'train.py')
-rw-r--r--   train.py   66
1 file changed, 66 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..eb0833d
--- /dev/null
+++ b/train.py
@@ -0,0 +1,66 @@
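+# Training script: builds the predictor, loss wrapper, dataset iterators and a
+# Chainer Trainer from a JSON config. Usage: python train.py <config_json_path>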
+import argparse
+from pathlib import Path
+
+from chainer.iterators import MultiprocessIterator
+from chainer import optimizers
+from chainer import training
+from chainer.training import extensions
+from chainer.dataset import convert
+
+from become_yukarin.config import create_from_json
+from become_yukarin.dataset import create as create_dataset
+from become_yukarin.model import create as create_model
+from become_yukarin.loss import Loss
+
+from functools import partial
+
+parser = argparse.ArgumentParser()
+parser.add_argument('config_json_path', type=Path)
+arguments = parser.parse_args()
+
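+# parse the JSON config and keep a copy next to the training outputs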
+config = create_from_json(arguments.config_json_path)
+config.train.output.mkdir(exist_ok=True)
+config.save_as_json((config.train.output / 'config.json').absolute())
+
+# model
+predictor = create_model(config.model)
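+# the Loss wrapper is what the optimizer trains; the 'main/loss' keys used
+# below imply it reports the loss value via chainer.report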
+model = Loss(config.loss, predictor=predictor)
+
+# dataset
+dataset = create_dataset(config.dataset)
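+# MultiprocessIterator prefetches batches in worker processes; the train
+# iterator repeats and shuffles by default, the other two make one ordered pass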
+train_iter = MultiprocessIterator(dataset['train'], config.train.batchsize)
+test_iter = MultiprocessIterator(dataset['test'], config.train.batchsize, repeat=False, shuffle=False)
+train_eval_iter = MultiprocessIterator(dataset['train_eval'], config.train.batchsize, repeat=False, shuffle=False)
+
+# optimizer
+optimizer = optimizers.Adam()
+optimizer.setup(model)
+
+# trainer
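+# fires whenever test/main/loss reaches a new minimum, checked every
+# snapshot_iteration iterations; used below to snapshot the best predictor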
+trigger_best = training.triggers.MinValueTrigger('test/main/loss', (config.train.snapshot_iteration, 'iteration'))
+
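+# concat_examples with padding=0 zero-pads variable-length examples so each
+# minibatch can be concatenated into a single array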
+converter = partial(convert.concat_examples, padding=0)
+updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter)
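+# no stop_trigger is given, so training runs until the process is interrupted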
+trainer = training.Trainer(updater, out=config.train.output)
+
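+# evaluate on the test set and on the train_eval subset; results are reported
+# under the 'test/' and 'train/' observation prefixes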
+ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu)
+trainer.extend(ext, name='test', trigger=(config.train.log_iteration, 'iteration'))
+ext = extensions.Evaluator(train_eval_iter, model, converter, device=config.train.gpu)
+trainer.extend(ext, name='train', trigger=(config.train.log_iteration, 'iteration'))
+
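+# write the computation graph of the loss to graph.dot (Graphviz format)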
+trainer.extend(extensions.dump_graph('main/loss', out_name='graph.dot'))
+
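+# save the predictor's weights whenever the test loss improves (trigger_best);
+# a sketch of how to restore such a snapshot follows the diff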
+ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
+trainer.extend(ext, trigger=trigger_best)
+
+trainer.extend(extensions.LogReport(trigger=(config.train.log_iteration, 'iteration'), log_name='log.txt'))
+
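+# PlotReport needs matplotlib; skip the loss plot if it is unavailable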
+if extensions.PlotReport.available():
+ trainer.extend(extensions.PlotReport(
+ y_keys=['main/loss', 'test/main/loss'],
+ x_key='iteration',
+ file_name='loss.png',
+ trigger=(config.train.log_iteration, 'iteration'),
+ ))
+
+trainer.run()
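
The snapshots written by snapshot_object contain only the predictor's parameters. A minimal sketch of restoring one for inference, assuming chainer is installed; 'config.json' and 'predictor_10000.npz' are placeholder names for the config that produced the run and a snapshot file taken from config.train.output:

    from pathlib import Path

    import chainer
    from become_yukarin.config import create_from_json
    from become_yukarin.model import create as create_model

    # rebuild the network with the same model config used for training
    config = create_from_json(Path('config.json'))  # placeholder path
    predictor = create_model(config.model)

    # snapshot_object saved the predictor with save_npz, so load_npz restores it
    chainer.serializers.load_npz('predictor_10000.npz', predictor)  # placeholder file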