| field | value | date |
|---|---|---|
| author | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:57:18 -0700 |
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:57:18 -0700 |
| commit | c039c2596d61f70382e78a6d16203ce820572585 | |
| tree | 8ca89cd0321d6876ef04357f714f8ed533f6a348 /models/networks.py | |
| parent | efef2906b0a3f4fd265823ff4b4b99ccebeb6d05 | |
fix assert bug
Diffstat (limited to 'models/networks.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | models/networks.py | 14 |

1 file changed, 8 insertions(+), 6 deletions(-)
```diff
diff --git a/models/networks.py b/models/networks.py
index 9cb6222..a7b6860 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -26,8 +26,9 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
         norm_layer = InstanceNormalization
     else:
         print('normalization layer [%s] is not found' % norm)
+    if use_gpu:
+        assert(torch.cuda.is_available())
 
-    assert(torch.cuda.is_available() == use_gpu)
     if which_model_netG == 'resnet_9blocks':
         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
@@ -48,7 +49,8 @@ def define_D(input_nc, ndf, which_model_netD,
              n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
-    assert(torch.cuda.is_available() == use_gpu)
+    if use_gpu:
+        assert(torch.cuda.is_available())
     if which_model_netD == 'basic':
         netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     elif which_model_netD == 'n_layers':
@@ -193,7 +195,7 @@ class ResnetBlock(nn.Module):
 # Defines the Unet generator.
 # |num_downs|: number of downsamplings in UNet. For example,
-# if |num_downs| == 7, image of size 128x128 will become of size 1x1
+# if |num_downs| == 7, image of size 128x128 will become of size 1x1
 # at the bottleneck
 class UnetGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
@@ -232,7 +234,7 @@ class UnetGenerator(nn.Module):
 # |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
-                 submodule=None, outermost=False, innermost=False):
+                 submodule=None, outermost=False, innermost=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outer_nc = outer_nc
         self.inner_nc = inner_nc
@@ -268,12 +270,12 @@ class UnetSkipConnectionBlock(nn.Module):
         self.model = nn.Sequential(*model)
 
-    def forward(self, x):
+    def forward(self, x):
         #print(self.outer_nc, self.inner_nc, self.innermost)
         #print(x.size())
         #print(self.model(x).size())
         return torch.cat([self.model(x), x], 1)
-
+
 
 # Defines the PatchGAN discriminator with the specified arguments.
```
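The fix matters on machines where CUDA is available but the user runs on CPU (`gpu_ids=[]`): the old equality assert then fails, while the new form only requires CUDA when a GPU was actually requested. A minimal sketch of the difference (the standalone `check_gpu` helper below is illustrative, not part of the repository):

```python
import torch

def check_gpu(gpu_ids=[]):
    use_gpu = len(gpu_ids) > 0
    # Old check: assert(torch.cuda.is_available() == use_gpu)
    # fails on a CUDA-capable machine when gpu_ids is empty,
    # because is_available() is True while use_gpu is False.
    # New check: only demand CUDA when GPUs were requested.
    if use_gpu:
        assert(torch.cuda.is_available())

check_gpu(gpu_ids=[])   # CPU run: never raises, regardless of CUDA
check_gpu(gpu_ids=[0])  # GPU run: raises unless CUDA is available
```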

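For reference, the `forward` touched in the last hunk implements the U-Net skip connection by concatenating the submodule's output with the block's input along the channel dimension. A small sketch of that pattern (the convolution and tensor shape below are made-up stand-ins, not the repository's modules):

```python
import torch
import torch.nn as nn

# Stand-in for self.model inside UnetSkipConnectionBlock: any module whose
# output spatial size matches its input (here one padded convolution).
inner = nn.Conv2d(3, 8, kernel_size=3, padding=1)

x = torch.randn(1, 3, 64, 64)      # (batch, channels, height, width)
out = torch.cat([inner(x), x], 1)  # concatenate along dim 1: 8 + 3 channels
print(out.shape)                   # torch.Size([1, 11, 64, 64])
```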