diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-09 20:11:52 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2018-01-09 20:11:52 +0900 |
| commit | f8af475646b812804bdbaddc5eed7df715ca3b9e (patch) | |
| tree | 4adc15a28c7337af140aba08466228a6599427c3 | |
| parent | 6b56d407a050c9c24fc0a0f6702bf5e9eee7450f (diff) | |
chainerUI
| -rw-r--r-- | requirements.txt | 1 | ||||
| -rw-r--r-- | train.py | 8 |
2 files changed, 6 insertions, 3 deletions
diff --git a/requirements.txt b/requirements.txt index 80fbb0d..19ca71a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ pyworld fastdtw nnmnkwii matplotlib +chainerui @@ -8,11 +8,12 @@ from chainer import training from chainer.dataset import convert from chainer.iterators import MultiprocessIterator from chainer.training import extensions +from chainerui.utils import save_args from become_yukarin.config import create_from_json from become_yukarin.dataset import create as create_dataset -from become_yukarin.updater import Updater from become_yukarin.model import create +from become_yukarin.updater import Updater parser = argparse.ArgumentParser() parser.add_argument('config_json_path', type=Path) @@ -74,12 +75,12 @@ trainer.extend(ext, name='test', trigger=trigger_log) ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward) trainer.extend(ext, name='train', trigger=trigger_log) -trainer.extend(extensions.dump_graph('predictor/loss', out_name='graph.dot')) +trainer.extend(extensions.dump_graph('predictor/loss')) ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz') trainer.extend(ext, trigger=trigger_snapshot) -trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt')) +trainer.extend(extensions.LogReport(trigger=trigger_log)) if extensions.PlotReport.available(): trainer.extend(extensions.PlotReport( @@ -98,4 +99,5 @@ if extensions.PlotReport.available(): trigger=trigger_log, )) +save_args(arguments, arguments.output) trainer.run() |
