diff options
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 8 |
1 files changed, 5 insertions, 3 deletions
@@ -8,11 +8,12 @@ from chainer import training from chainer.dataset import convert from chainer.iterators import MultiprocessIterator from chainer.training import extensions +from chainerui.utils import save_args from become_yukarin.config import create_from_json from become_yukarin.dataset import create as create_dataset -from become_yukarin.updater import Updater from become_yukarin.model import create +from become_yukarin.updater import Updater parser = argparse.ArgumentParser() parser.add_argument('config_json_path', type=Path) @@ -74,12 +75,12 @@ trainer.extend(ext, name='test', trigger=trigger_log) ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward) trainer.extend(ext, name='train', trigger=trigger_log) -trainer.extend(extensions.dump_graph('predictor/loss', out_name='graph.dot')) +trainer.extend(extensions.dump_graph('predictor/loss')) ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz') trainer.extend(ext, trigger=trigger_snapshot) -trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt')) +trainer.extend(extensions.LogReport(trigger=trigger_log)) if extensions.PlotReport.available(): trainer.extend(extensions.PlotReport( @@ -98,4 +99,5 @@ if extensions.PlotReport.available(): trigger=trigger_log, )) +save_args(arguments, arguments.output) trainer.run() |
