diff options
Diffstat (limited to 'options')
| -rw-r--r-- | options/base_options.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/options/base_options.py b/options/base_options.py index b5b92fb..275c8fc 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -1,7 +1,7 @@ import argparse import os from util import util - +import torch class BaseOptions(): def __init__(self): @@ -54,6 +54,10 @@ class BaseOptions(): id = int(str_id) if id >= 0: self.opt.gpu_ids.append(id) + + # set gpu ids + if len(self.opt.gpu_ids) > 0: + torch.cuda.set_device(self.opt.gpu_ids[0]) args = vars(self.opt) |
