summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-18 03:57:18 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-18 03:57:18 -0700
commitc039c2596d61f70382e78a6d16203ce820572585 (patch)
tree8ca89cd0321d6876ef04357f714f8ed533f6a348 /models/networks.py
parentefef2906b0a3f4fd265823ff4b4b99ccebeb6d05 (diff)
fix assert bug
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/models/networks.py b/models/networks.py
index 9cb6222..a7b6860 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -26,8 +26,9 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
norm_layer = InstanceNormalization
else:
print('normalization layer [%s] is not found' % norm)
+ if use_gpu:
+ assert(torch.cuda.is_available())
- assert(torch.cuda.is_available() == use_gpu)
if which_model_netG == 'resnet_9blocks':
netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
elif which_model_netG == 'resnet_6blocks':
@@ -48,7 +49,8 @@ def define_D(input_nc, ndf, which_model_netD,
n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
netD = None
use_gpu = len(gpu_ids) > 0
- assert(torch.cuda.is_available() == use_gpu)
+ if use_gpu:
+ assert(torch.cuda.is_available())
if which_model_netD == 'basic':
netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
elif which_model_netD == 'n_layers':
@@ -193,7 +195,7 @@ class ResnetBlock(nn.Module):
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
-# if |num_downs| == 7, image of size 128x128 will become of size 1x1
+# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
@@ -232,7 +234,7 @@ class UnetGenerator(nn.Module):
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc,
- submodule=None, outermost=False, innermost=False):
+ submodule=None, outermost=False, innermost=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outer_nc = outer_nc
self.inner_nc = inner_nc
@@ -268,12 +270,12 @@ class UnetSkipConnectionBlock(nn.Module):
self.model = nn.Sequential(*model)
- def forward(self, x):
+ def forward(self, x):
#print(self.outer_nc, self.inner_nc, self.innermost)
#print(x.size())
#print(self.model(x).size())
return torch.cat([self.model(x), x], 1)
-
+
# Defines the PatchGAN discriminator with the specified arguments.