summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-22 20:25:09 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-22 20:25:09 -0700
commit9137b1146f9b0edb14d7c4421ed7abed1583b025 (patch)
treef643838e27813345408b26dd55ddd36454a9f249 /models/networks.py
parent6c347282993d2e2db91b376d3113efa3774c3a22 (diff)
add comment for python 2/3
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/models/networks.py b/models/networks.py
index 60e1777..2aea150 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -13,7 +13,7 @@ def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
- elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+ elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
@@ -162,7 +162,7 @@ class ResnetGenerator(nn.Module):
self.model = nn.Sequential(*model)
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
@@ -222,7 +222,7 @@ class UnetGenerator(nn.Module):
self.model = unet_block
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
@@ -323,7 +323,7 @@ class NLayerDiscriminator(nn.Module):
self.model = nn.Sequential(*sequence)
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)