summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/train.py b/train.py
index 08ef2d9..a9f4e79 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import argparse
from functools import partial
from pathlib import Path
+from chainer import cuda
from chainer import optimizers
from chainer import training
from chainer.dataset import convert
@@ -24,6 +25,8 @@ arguments.output.mkdir(exist_ok=True)
config.save_as_json((arguments.output / 'config.json').absolute())
# model
+if config.train.gpu >= 0:
+ cuda.get_device_from_id(config.train.gpu).use()
predictor = create_predictor(config.model)
aligner = create_aligner(config.model)
model = Loss(config.loss, predictor=predictor, aligner=aligner)