summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-15 02:27:33 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-15 02:27:33 +0900
commit4003ae8e457905070b789b75c5972ca93cc5756b (patch)
treed79bd6ba959ee673adcc89e8d5c69c0ec8cf0d93 /train.py
parenta4f60ab4cd44d1fc89e83bb662fe430e3824d0dc (diff)
little modify
Diffstat (limited to 'train.py')
-rw-r--r--train.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/train.py b/train.py
index b1a2213..27fd1fb 100644
--- a/train.py
+++ b/train.py
@@ -16,11 +16,12 @@ from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
+parser.add_argument('output', type=Path)
arguments = parser.parse_args()
config = create_from_json(arguments.config_json_path)
-config.train.output.mkdir(exist_ok=True)
-config.save_as_json((config.train.output / 'config.json').absolute())
+arguments.output.mkdir(exist_ok=True)
+config.save_as_json((arguments.output / 'config.json').absolute())
# model
predictor = create_model(config.model)
@@ -42,7 +43,7 @@ trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
converter = partial(convert.concat_examples, padding=0)
updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter)
-trainer = training.Trainer(updater, out=config.train.output)
+trainer = training.Trainer(updater, out=arguments.output)
ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu)
trainer.extend(ext, name='test', trigger=trigger_log)