diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-15 02:27:33 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-15 02:27:33 +0900 |
| commit | 4003ae8e457905070b789b75c5972ca93cc5756b (patch) | |
| tree | d79bd6ba959ee673adcc89e8d5c69c0ec8cf0d93 /train.py | |
| parent | a4f60ab4cd44d1fc89e83bb662fe430e3824d0dc (diff) | |
little modify
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 7 |
1 files changed, 4 insertions, 3 deletions
@@ -16,11 +16,12 @@ from functools import partial parser = argparse.ArgumentParser() parser.add_argument('config_json_path', type=Path) +parser.add_argument('output', type=Path) arguments = parser.parse_args() config = create_from_json(arguments.config_json_path) -config.train.output.mkdir(exist_ok=True) -config.save_as_json((config.train.output / 'config.json').absolute()) +arguments.output.mkdir(exist_ok=True) +config.save_as_json((arguments.output / 'config.json').absolute()) # model predictor = create_model(config.model) @@ -42,7 +43,7 @@ trigger_snapshot = (config.train.snapshot_iteration, 'iteration') converter = partial(convert.concat_examples, padding=0) updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter) -trainer = training.Trainer(updater, out=config.train.output) +trainer = training.Trainer(updater, out=arguments.output) ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu) trainer.extend(ext, name='test', trigger=trigger_log) |
