From 56db5630e7ec32de85fde7be2ba48d71c51ec3b9 Mon Sep 17 00:00:00 2001 From: junyanz Date: Tue, 18 Apr 2017 03:41:57 -0700 Subject: fix assert bug --- models/networks.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'models/networks.py') diff --git a/models/networks.py b/models/networks.py index d41bd0e..edbe972 100644 --- a/models/networks.py +++ b/models/networks.py @@ -26,8 +26,9 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]): norm_layer = InstanceNormalization else: print('normalization layer [%s] is not found' % norm) - - assert(torch.cuda.is_available() == use_gpu) + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netG == 'resnet_9blocks': netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids) elif which_model_netG == 'resnet_6blocks': @@ -46,7 +47,9 @@ def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, use_sigmoid=False, gpu_ids=[]): netD = None use_gpu = len(gpu_ids) > 0 - assert(torch.cuda.is_available() == use_gpu) + if use_gpu: + assert(torch.cuda.is_available()) + if which_model_netD == 'basic': netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) elif which_model_netD == 'n_layers': -- cgit v1.2.3-70-g09d2