summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-13 07:50:22 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-13 20:50:25 +0900
commit1436960f9dcd9572f7a68b41b9a3ed2de0bbad85 (patch)
tree1153c5c1d4a245719dd919dfdd8c1913e13e7996 /train.py
parent6b2bae905e59d3b8756c624e38a447786c2b9e9d (diff)
WIP
Diffstat (limited to 'train.py')
-rw-r--r--train.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/train.py b/train.py
index eb0833d..b1a2213 100644
--- a/train.py
+++ b/train.py
@@ -37,30 +37,31 @@ optimizer = optimizers.Adam()
optimizer.setup(model)
# trainer
-trigger_best = training.triggers.MinValueTrigger('test/main/loss', (config.train.snapshot_iteration, 'iteration'))
+trigger_log = (config.train.log_iteration, 'iteration')
+trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
converter = partial(convert.concat_examples, padding=0)
updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter)
trainer = training.Trainer(updater, out=config.train.output)
ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu)
-trainer.extend(ext, name='test', trigger=(config.train.log_iteration, 'iteration'))
+trainer.extend(ext, name='test', trigger=trigger_log)
ext = extensions.Evaluator(train_eval_iter, model, converter, device=config.train.gpu)
-trainer.extend(ext, name='train', trigger=(config.train.log_iteration, 'iteration'))
+trainer.extend(ext, name='train', trigger=trigger_log)
trainer.extend(extensions.dump_graph('main/loss', out_name='graph.dot'))
ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
-trainer.extend(ext, trigger=trigger_best)
+trainer.extend(ext, trigger=trigger_snapshot)
-trainer.extend(extensions.LogReport(trigger=(config.train.log_iteration, 'iteration'), log_name='log.txt'))
+trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt'))
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(
- y_keys=['main/loss', 'test/main/loss'],
+ y_keys=['main/loss', 'test/main/loss', 'train/main/loss'],
x_key='iteration',
file_name='loss.png',
- trigger=(config.train.log_iteration, 'iteration'),
+ trigger=trigger_log,
))
trainer.run()