From f8af475646b812804bdbaddc5eed7df715ca3b9e Mon Sep 17 00:00:00 2001 From: Hiroshiba Kazuyuki Date: Tue, 9 Jan 2018 20:11:52 +0900 Subject: chainerUI --- train.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index f62d338..346ec4b 100644 --- a/train.py +++ b/train.py @@ -8,11 +8,12 @@ from chainer import training from chainer.dataset import convert from chainer.iterators import MultiprocessIterator from chainer.training import extensions +from chainerui.utils import save_args from become_yukarin.config import create_from_json from become_yukarin.dataset import create as create_dataset -from become_yukarin.updater import Updater from become_yukarin.model import create +from become_yukarin.updater import Updater parser = argparse.ArgumentParser() parser.add_argument('config_json_path', type=Path) @@ -74,12 +75,12 @@ trainer.extend(ext, name='test', trigger=trigger_log) ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward) trainer.extend(ext, name='train', trigger=trigger_log) -trainer.extend(extensions.dump_graph('predictor/loss', out_name='graph.dot')) +trainer.extend(extensions.dump_graph('predictor/loss')) ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz') trainer.extend(ext, trigger=trigger_snapshot) -trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt')) +trainer.extend(extensions.LogReport(trigger=trigger_log)) if extensions.PlotReport.available(): trainer.extend(extensions.PlotReport( @@ -98,4 +99,5 @@ if extensions.PlotReport.available(): trigger=trigger_log, )) +save_args(arguments, arguments.output) trainer.run() -- cgit v1.2.3-70-g09d2