summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-21 06:37:41 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-21 06:37:41 +0900
commitd6af2a851644afe253b97461b35138011a479a95 (patch)
treebc94f2d6e6723ee3240032f901175d9501d512c2 /train.py
parent16b4e72fe6728e2e64d4c6357b7c73ac06868c1c (diff)
modify aligner
Diffstat (limited to 'train.py')
-rw-r--r--train.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/train.py b/train.py
index 08ef2d9..a9f4e79 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import argparse
from functools import partial
from pathlib import Path
+from chainer import cuda
from chainer import optimizers
from chainer import training
from chainer.dataset import convert
@@ -24,6 +25,8 @@ arguments.output.mkdir(exist_ok=True)
config.save_as_json((arguments.output / 'config.json').absolute())
# model
+if config.train.gpu >= 0:
+ cuda.get_device_from_id(config.train.gpu).use()
predictor = create_predictor(config.model)
aligner = create_aligner(config.model)
model = Loss(config.loss, predictor=predictor, aligner=aligner)