-rw-r--r--  data/aligned_data_loader.py   |  1 +
-rw-r--r--  data/unaligned_data_loader.py |  1 +
-rw-r--r--  models/networks.py            |  8 ++++----
-rw-r--r--  options/train_options.py      |  4 ++--
4 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index bea3531..c5305fc 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -4,6 +4,7 @@ import torchvision.transforms as transforms
 from data.base_data_loader import BaseDataLoader
 from data.image_folder import ImageFolder
 from pdb import set_trace as st
+# pip install future --upgrade
 from builtins import object
 
 class PairedData(object):
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 4a06510..f37f6ce 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -2,6 +2,7 @@ import torch.utils.data
 import torchvision.transforms as transforms
 from data.base_data_loader import BaseDataLoader
 from data.image_folder import ImageFolder
+# pip install future --upgrade
 from builtins import object
 from pdb import set_trace as st
 
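Both loader hunks add the same install hint above `from builtins import object`. On Python 2 that import comes from the third-party `future` package (hence `pip install future --upgrade`); it back-ports Python 3 class semantics, notably the `__next__` iterator protocol. A minimal sketch of the pattern, only loosely modeled on the repo's `PairedData` (whose real body wraps two DataLoaders and is not shown here):

    # pip install future --upgrade   # provides `builtins` on Python 2
    from builtins import object

    class PairedData(object):  # new-style class on both Python 2 and 3
        def __init__(self, items):
            self.items = list(items)

        def __iter__(self):
            self.i = 0
            return self

        def __next__(self):  # Python 3 protocol; `future` aliases Python 2's next() to this
            if self.i >= len(self.items):
                raise StopIteration
            item = self.items[self.i]
            self.i += 1
            return item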
diff --git a/models/networks.py b/models/networks.py
index 60e1777..2aea150 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -13,7 +13,7 @@ def weights_init(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
         m.weight.data.normal_(0.0, 0.02)
-    elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+    elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1:
         m.weight.data.normal_(1.0, 0.02)
         m.bias.data.fill_(0)
 
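The init function matches only on `m.__class__.__name__`, so the tightened strings pin the Gaussian init to `BatchNorm2d` and to what is presumably this repo's custom `InstanceNormalization` layer, instead of any class whose name merely contains 'BatchNorm' or 'InstanceNorm'. For reference, a sketch of how such an initializer is typically wired up, assuming a plain torch model:

    import torch.nn as nn

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)   # N(0, 0.02) for conv weights
        elif classname.find('BatchNorm2d') != -1:
            m.weight.data.normal_(1.0, 0.02)   # scale near 1
            m.bias.data.fill_(0)               # zero shift

    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    net.apply(weights_init)  # .apply() visits every submodule recursively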
@@ -162,7 +162,7 @@ class ResnetGenerator(nn.Module):
         self.model = nn.Sequential(*model)
 
     def forward(self, input):
-        if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
         else:
             return self.model(input)
@@ -222,7 +222,7 @@ class UnetGenerator(nn.Module):
         self.model = unet_block
 
     def forward(self, input):
-        if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
         else:
             return self.model(input)
@@ -323,7 +323,7 @@ class NLayerDiscriminator(nn.Module):
         self.model = nn.Sequential(*sequence)
 
     def forward(self, input):
-        if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
         else:
             return self.model(input)
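All three `forward` hunks make the same reorder, and the likely motivation is Python's short-circuit evaluation: when `gpu_ids` is empty (a CPU run), the cheap truthiness test now fails first and the `torch.cuda.FloatTensor` check is never evaluated. `len(self.gpu_ids)` in the discriminator and the bare `self.gpu_ids` in the generators are equivalent here, since an empty list is falsy. A condensed illustration (hypothetical `model`/`input` names):

    gpu_ids = []  # CPU-only configuration

    # Old order: isinstance(input.data, torch.cuda.FloatTensor) ran even with no GPUs.
    # New order: the falsy gpu_ids short-circuits, so torch.cuda is never consulted.
    if gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
        out = nn.parallel.data_parallel(model, input, gpu_ids)
    else:
        out = model(input)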
diff --git a/options/train_options.py b/options/train_options.py
index b241863..981126f 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -10,8 +10,8 @@ class TrainOptions(BaseOptions):
         self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
         self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
         self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
-        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
-        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
+        self.parser.add_argument('--niter', type=int, default=200, help='# of iter at starting learning rate')
+        self.parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
         self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
         self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.')
         self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
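Per the help strings, training runs at the initial learning rate for `--niter` epochs and then decays it linearly to zero over `--niter_decay` epochs, so the new defaults (200 / 0) trade the old 100-constant-plus-100-decay schedule for 200 epochs at a flat 0.0002. A small sketch of that schedule (my helper name, not the repo's update code):

    def lr_at_epoch(epoch, lr=0.0002, niter=200, niter_decay=0):
        # Flat for the first `niter` epochs, then linear decay to zero.
        if epoch <= niter or niter_decay == 0:
            return lr
        return lr * max(0.0, 1.0 - float(epoch - niter) / niter_decay)

    lr_at_epoch(150, niter=100, niter_decay=100)  # old defaults: 0.0001, halfway decayed
    lr_at_epoch(150)                              # new defaults: 0.0002, still flat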