diff options
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/models/networks.py b/models/networks.py index d41bd0e..edbe972 100644 --- a/models/networks.py +++ b/models/networks.py @@ -26,8 +26,9 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]): norm_layer = InstanceNormalization else: print('normalization layer [%s] is not found' % norm) - - assert(torch.cuda.is_available() == use_gpu) + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netG == 'resnet_9blocks': netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids) elif which_model_netG == 'resnet_6blocks': @@ -46,7 +47,9 @@ def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, use_sigmoid=False, gpu_ids=[]): netD = None use_gpu = len(gpu_ids) > 0 - assert(torch.cuda.is_available() == use_gpu) + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netD == 'basic': netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) elif which_model_netD == 'n_layers': |
