summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-09 20:11:52 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-09 20:11:52 +0900
commitf8af475646b812804bdbaddc5eed7df715ca3b9e (patch)
tree4adc15a28c7337af140aba08466228a6599427c3 /train.py
parent6b56d407a050c9c24fc0a0f6702bf5e9eee7450f (diff)
chainerUI
Diffstat (limited to 'train.py')
-rw-r--r--train.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/train.py b/train.py
index f62d338..346ec4b 100644
--- a/train.py
+++ b/train.py
@@ -8,11 +8,12 @@ from chainer import training
from chainer.dataset import convert
from chainer.iterators import MultiprocessIterator
from chainer.training import extensions
+from chainerui.utils import save_args
from become_yukarin.config import create_from_json
from become_yukarin.dataset import create as create_dataset
-from become_yukarin.updater import Updater
from become_yukarin.model import create
+from become_yukarin.updater import Updater
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
@@ -74,12 +75,12 @@ trainer.extend(ext, name='test', trigger=trigger_log)
ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward)
trainer.extend(ext, name='train', trigger=trigger_log)
-trainer.extend(extensions.dump_graph('predictor/loss', out_name='graph.dot'))
+trainer.extend(extensions.dump_graph('predictor/loss'))
ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
trainer.extend(ext, trigger=trigger_snapshot)
-trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt'))
+trainer.extend(extensions.LogReport(trigger=trigger_log))
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(
@@ -98,4 +99,5 @@ if extensions.PlotReport.available():
trigger=trigger_log,
))
+save_args(arguments, arguments.output)
trainer.run()