summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-13 07:50:22 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-13 07:50:22 +0900
commite659b95dba4a071ed2669fc2143c5fd644ca44ca (patch)
treeaa2b9f6e1939d63702a06ba7964b9cc9cb3d667f /train.py
parent6b2bae905e59d3b8756c624e38a447786c2b9e9d (diff)
WIP
Diffstat (limited to 'train.py')
-rw-r--r--train.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/train.py b/train.py
index eb0833d..eeb77d2 100644
--- a/train.py
+++ b/train.py
@@ -37,30 +37,31 @@ optimizer = optimizers.Adam()
optimizer.setup(model)
# trainer
-trigger_best = training.triggers.MinValueTrigger('test/main/loss', (config.train.snapshot_iteration, 'iteration'))
+trigger_log = (config.train.log_iteration, 'iteration')
+trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
converter = partial(convert.concat_examples, padding=0)
updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter)
trainer = training.Trainer(updater, out=config.train.output)
ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu)
-trainer.extend(ext, name='test', trigger=(config.train.log_iteration, 'iteration'))
+trainer.extend(ext, name='test', trigger=trigger_log)
ext = extensions.Evaluator(train_eval_iter, model, converter, device=config.train.gpu)
-trainer.extend(ext, name='train', trigger=(config.train.log_iteration, 'iteration'))
+trainer.extend(ext, name='train', trigger=trigger_log)
trainer.extend(extensions.dump_graph('main/loss', out_name='graph.dot'))
ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
-trainer.extend(ext, trigger=trigger_best)
+trainer.extend(ext, trigger=trigger_snapshot)
-trainer.extend(extensions.LogReport(trigger=(config.train.log_iteration, 'iteration'), log_name='log.txt'))
+trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt'))
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(
y_keys=['main/loss', 'test/main/loss'],
x_key='iteration',
file_name='loss.png',
- trigger=(config.train.log_iteration, 'iteration'),
+ trigger=trigger_log,
))
trainer.run()