summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-09 20:11:52 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-09 20:11:52 +0900
commitf8af475646b812804bdbaddc5eed7df715ca3b9e (patch)
tree4adc15a28c7337af140aba08466228a6599427c3
parent6b56d407a050c9c24fc0a0f6702bf5e9eee7450f (diff)
chainerUI
-rw-r--r--requirements.txt1
-rw-r--r--train.py8
2 files changed, 6 insertions, 3 deletions
diff --git a/requirements.txt b/requirements.txt
index 80fbb0d..19ca71a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ pyworld
fastdtw
nnmnkwii
matplotlib
+chainerui
diff --git a/train.py b/train.py
index f62d338..346ec4b 100644
--- a/train.py
+++ b/train.py
@@ -8,11 +8,12 @@ from chainer import training
from chainer.dataset import convert
from chainer.iterators import MultiprocessIterator
from chainer.training import extensions
+from chainerui.utils import save_args
from become_yukarin.config import create_from_json
from become_yukarin.dataset import create as create_dataset
-from become_yukarin.updater import Updater
from become_yukarin.model import create
+from become_yukarin.updater import Updater
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
@@ -74,12 +75,12 @@ trainer.extend(ext, name='test', trigger=trigger_log)
ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward)
trainer.extend(ext, name='train', trigger=trigger_log)
-trainer.extend(extensions.dump_graph('predictor/loss', out_name='graph.dot'))
+trainer.extend(extensions.dump_graph('predictor/loss'))
ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
trainer.extend(ext, trigger=trigger_snapshot)
-trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt'))
+trainer.extend(extensions.LogReport(trigger=trigger_log))
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(
@@ -98,4 +99,5 @@ if extensions.PlotReport.available():
trigger=trigger_log,
))
+save_args(arguments, arguments.output)
trainer.run()