diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:41:57 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:41:57 -0700 |
| commit | 56db5630e7ec32de85fde7be2ba48d71c51ec3b9 (patch) | |
| tree | fd45641f9a124c9c957f001affb352fd143b9620 /models/networks.py | |
| parent | c99ce7c4e781712e0252c6127ad1a4e8021cc489 (diff) | |
fix assert bug
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/models/networks.py b/models/networks.py index d41bd0e..edbe972 100644 --- a/models/networks.py +++ b/models/networks.py @@ -26,8 +26,9 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]): norm_layer = InstanceNormalization else: print('normalization layer [%s] is not found' % norm) - - assert(torch.cuda.is_available() == use_gpu) + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netG == 'resnet_9blocks': netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids) elif which_model_netG == 'resnet_6blocks': @@ -46,7 +47,9 @@ def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, use_sigmoid=False, gpu_ids=[]): netD = None use_gpu = len(gpu_ids) > 0 - assert(torch.cuda.is_available() == use_gpu) + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netD == 'basic': netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) elif which_model_netD == 'n_layers': |
