diff options
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/models/networks.py b/models/networks.py index 60e1777..b0f3b11 100644 --- a/models/networks.py +++ b/models/networks.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn from torch.autograd import Variable -from pdb import set_trace as st import numpy as np ############################################################################### @@ -13,7 +12,7 @@ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1: + elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) @@ -162,7 +161,7 @@ class ResnetGenerator(nn.Module): self.model = nn.Sequential(*model) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -222,7 +221,7 @@ class UnetGenerator(nn.Module): self.model = unet_block def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -323,7 +322,7 @@ class NLayerDiscriminator(nn.Module): self.model = nn.Sequential(*sequence) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) |
