From 9137b1146f9b0edb14d7c4421ed7abed1583b025 Mon Sep 17 00:00:00 2001 From: junyanz Date: Sat, 22 Apr 2017 20:25:09 -0700 Subject: add comment for python 2/3 --- models/networks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'models/networks.py') diff --git a/models/networks.py b/models/networks.py index 60e1777..2aea150 100644 --- a/models/networks.py +++ b/models/networks.py @@ -13,7 +13,7 @@ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1: + elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) @@ -162,7 +162,7 @@ class ResnetGenerator(nn.Module): self.model = nn.Sequential(*model) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -222,7 +222,7 @@ class UnetGenerator(nn.Module): self.model = unet_block def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -323,7 +323,7 @@ class NLayerDiscriminator(nn.Module): self.model = nn.Sequential(*sequence) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) -- cgit v1.2.3-70-g09d2