diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-04-22 20:25:09 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-22 20:25:09 -0700 |
| commit | 9137b1146f9b0edb14d7c4421ed7abed1583b025 (patch) | |
| tree | f643838e27813345408b26dd55ddd36454a9f249 /models/networks.py | |
| parent | 6c347282993d2e2db91b376d3113efa3774c3a22 (diff) | |
add comment for python 2/3
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/models/networks.py b/models/networks.py index 60e1777..2aea150 100644 --- a/models/networks.py +++ b/models/networks.py @@ -13,7 +13,7 @@ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1: + elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) @@ -162,7 +162,7 @@ class ResnetGenerator(nn.Module): self.model = nn.Sequential(*model) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -222,7 +222,7 @@ class UnetGenerator(nn.Module): self.model = unet_block def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -323,7 +323,7 @@ class NLayerDiscriminator(nn.Module): self.model = nn.Sequential(*sequence) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) |
