From 1436960f9dcd9572f7a68b41b9a3ed2de0bbad85 Mon Sep 17 00:00:00 2001 From: Hiroshiba Kazuyuki Date: Mon, 13 Nov 2017 07:50:22 +0900 Subject: WIP --- train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index eb0833d..b1a2213 100644 --- a/train.py +++ b/train.py @@ -37,30 +37,31 @@ optimizer = optimizers.Adam() optimizer.setup(model) # trainer -trigger_best = training.triggers.MinValueTrigger('test/main/loss', (config.train.snapshot_iteration, 'iteration')) +trigger_log = (config.train.log_iteration, 'iteration') +trigger_snapshot = (config.train.snapshot_iteration, 'iteration') converter = partial(convert.concat_examples, padding=0) updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter) trainer = training.Trainer(updater, out=config.train.output) ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu) -trainer.extend(ext, name='test', trigger=(config.train.log_iteration, 'iteration')) +trainer.extend(ext, name='test', trigger=trigger_log) ext = extensions.Evaluator(train_eval_iter, model, converter, device=config.train.gpu) -trainer.extend(ext, name='train', trigger=(config.train.log_iteration, 'iteration')) +trainer.extend(ext, name='train', trigger=trigger_log) trainer.extend(extensions.dump_graph('main/loss', out_name='graph.dot')) ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz') -trainer.extend(ext, trigger=trigger_best) +trainer.extend(ext, trigger=trigger_snapshot) -trainer.extend(extensions.LogReport(trigger=(config.train.log_iteration, 'iteration'), log_name='log.txt')) +trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt')) if extensions.PlotReport.available(): trainer.extend(extensions.PlotReport( - y_keys=['main/loss', 'test/main/loss'], + y_keys=['main/loss', 'test/main/loss', 'train/main/loss'], x_key='iteration', file_name='loss.png', - trigger=(config.train.log_iteration, 'iteration'), + trigger=trigger_log, )) trainer.run() -- cgit v1.2.3-70-g09d2