diff options
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 3 |
1 files changed, 3 insertions, 0 deletions
@@ -2,6 +2,7 @@ import argparse from functools import partial from pathlib import Path +from chainer import cuda from chainer import optimizers from chainer import training from chainer.dataset import convert @@ -24,6 +25,8 @@ arguments.output.mkdir(exist_ok=True) config.save_as_json((arguments.output / 'config.json').absolute()) # model +if config.train.gpu >= 0: + cuda.get_device_from_id(config.train.gpu).use() predictor = create_predictor(config.model) aligner = create_aligner(config.model) model = Loss(config.loss, predictor=predictor, aligner=aligner) |
