diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-13 07:50:22 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-13 20:50:25 +0900 |
| commit | 1436960f9dcd9572f7a68b41b9a3ed2de0bbad85 (patch) | |
| tree | 1153c5c1d4a245719dd919dfdd8c1913e13e7996 /train.py | |
| parent | 6b2bae905e59d3b8756c624e38a447786c2b9e9d (diff) | |
WIP
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 15 |
1 files changed, 8 insertions, 7 deletions
@@ -37,30 +37,31 @@ optimizer = optimizers.Adam() optimizer.setup(model) # trainer -trigger_best = training.triggers.MinValueTrigger('test/main/loss', (config.train.snapshot_iteration, 'iteration')) +trigger_log = (config.train.log_iteration, 'iteration') +trigger_snapshot = (config.train.snapshot_iteration, 'iteration') converter = partial(convert.concat_examples, padding=0) updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter) trainer = training.Trainer(updater, out=config.train.output) ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu) -trainer.extend(ext, name='test', trigger=(config.train.log_iteration, 'iteration')) +trainer.extend(ext, name='test', trigger=trigger_log) ext = extensions.Evaluator(train_eval_iter, model, converter, device=config.train.gpu) -trainer.extend(ext, name='train', trigger=(config.train.log_iteration, 'iteration')) +trainer.extend(ext, name='train', trigger=trigger_log) trainer.extend(extensions.dump_graph('main/loss', out_name='graph.dot')) ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz') -trainer.extend(ext, trigger=trigger_best) +trainer.extend(ext, trigger=trigger_snapshot) -trainer.extend(extensions.LogReport(trigger=(config.train.log_iteration, 'iteration'), log_name='log.txt')) +trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt')) if extensions.PlotReport.available(): trainer.extend(extensions.PlotReport( - y_keys=['main/loss', 'test/main/loss'], + y_keys=['main/loss', 'test/main/loss', 'train/main/loss'], x_key='iteration', file_name='loss.png', - trigger=(config.train.log_iteration, 'iteration'), + trigger=trigger_log, )) trainer.run() |
