summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/models/networks.py b/models/networks.py
index d41bd0e..edbe972 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -26,8 +26,9 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
norm_layer = InstanceNormalization
else:
print('normalization layer [%s] is not found' % norm)
-
- assert(torch.cuda.is_available() == use_gpu)
+ if use_gpu:
+ assert(torch.cuda.is_available())
+
if which_model_netG == 'resnet_9blocks':
netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
elif which_model_netG == 'resnet_6blocks':
@@ -46,7 +47,9 @@ def define_D(input_nc, ndf, which_model_netD,
n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
netD = None
use_gpu = len(gpu_ids) > 0
- assert(torch.cuda.is_available() == use_gpu)
+ if use_gpu:
+ assert(torch.cuda.is_available())
+
if which_model_netD == 'basic':
netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
elif which_model_netD == 'n_layers':