 data/aligned_data_loader.py         | 17
 data/unaligned_data_loader.py       | 16
 models/base_model.py                |  1
 models/cycle_gan_model.py           |  3
 models/models.py                    |  9
 models/networks.py                  |  9
 models/one_direction_test_model.py  | 51
 models/pix2pix_model.py             |  5
 options/base_options.py             |  4
 options/train_options.py            |  5
 scripts/test_pix2pix.sh             |  2
 11 files changed, 91 insertions(+), 31 deletions(-)
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index bea3531..a1efde8 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -4,19 +4,26 @@ import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
from pdb import set_trace as st
+# pip install future --upgrade
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize):
+ def __init__(self, data_loader, fineSize, max_dataset_size):
self.data_loader = data_loader
self.fineSize = fineSize
+ self.max_dataset_size = max_dataset_size
# st()
def __iter__(self):
self.data_loader_iter = iter(self.data_loader)
+ self.iter = 0
return self
def __next__(self):
+ self.iter += 1
+ if self.iter > self.max_dataset_size:
+ raise StopIteration
+
AB, AB_paths = next(self.data_loader_iter)
w_total = AB.size(3)
w = int(w_total / 2)
@@ -24,7 +31,6 @@ class PairedData(object):
w_offset = random.randint(0, max(0, w - self.fineSize - 1))
h_offset = random.randint(0, max(0, h - self.fineSize - 1))
-
A = AB[:, :, h_offset:h_offset + self.fineSize,
w_offset:w_offset + self.fineSize]
B = AB[:, :, h_offset:h_offset + self.fineSize,
@@ -39,8 +45,7 @@ class AlignedDataLoader(BaseDataLoader):
self.fineSize = opt.fineSize
transform = transforms.Compose([
# TODO: Scale
- #transforms.Scale((opt.loadSize * 2, opt.loadSize)),
- #transforms.CenterCrop(opt.fineSize),
+ transforms.Scale(opt.loadSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
@@ -55,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize)
+ self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size)
def name(self):
return 'AlignedDataLoader'
@@ -64,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return len(self.dataset)
+ return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 4a06510..77f9274 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -2,21 +2,24 @@ import torch.utils.data
import torchvision.transforms as transforms
from data.base_data_loader import BaseDataLoader
from data.image_folder import ImageFolder
+# pip install future --upgrade
from builtins import object
from pdb import set_trace as st
class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B):
+ def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
+ self.max_dataset_size = max_dataset_size
def __iter__(self):
self.stop_A = False
self.stop_B = False
self.data_loader_A_iter = iter(self.data_loader_A)
self.data_loader_B_iter = iter(self.data_loader_B)
+ self.iter = 0
return self
def __next__(self):
@@ -29,20 +32,21 @@ class PairedData(object):
self.stop_A = True
self.data_loader_A_iter = iter(self.data_loader_A)
A, A_paths = next(self.data_loader_A_iter)
+
try:
B, B_paths = next(self.data_loader_B_iter)
-
except StopIteration:
if B is None or B_paths is None:
self.stop_B = True
self.data_loader_B_iter = iter(self.data_loader_B)
B, B_paths = next(self.data_loader_B_iter)
- if self.stop_A and self.stop_B:
+ if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
self.stop_A = False
self.stop_B = False
raise StopIteration()
else:
+ self.iter += 1
return {'A': A, 'A_paths': A_paths,
'B': B, 'B_paths': B_paths}
@@ -51,7 +55,7 @@ class UnalignedDataLoader(BaseDataLoader):
BaseDataLoader.initialize(self, opt)
transform = transforms.Compose([
transforms.Scale(opt.loadSize),
- transforms.CenterCrop(opt.fineSize),
+ transforms.RandomCrop(opt.fineSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))])
@@ -75,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
- self.paired_data = PairedData(data_loader_A, data_loader_B)
+ self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size)
def name(self):
return 'UnalignedDataLoader'
@@ -84,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return max(len(self.dataset_A), len(self.dataset_B))
+ return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)
diff --git a/models/base_model.py b/models/base_model.py
index ce18635..9b92bb4 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -1,6 +1,5 @@
import os
import torch
-from pdb import set_trace as st
class BaseModel():
def name(self):
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index d361e47..451002d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -2,7 +2,6 @@ import numpy as np
import torch
import os
from collections import OrderedDict
-from pdb import set_trace as st
from torch.autograd import Variable
import itertools
import util.util as util
@@ -72,7 +71,7 @@ class CycleGANModel(BaseModel):
print('-----------------------------------------------')
def set_input(self, input):
- AtoB = self.opt.which_direction is 'AtoB'
+ AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
input_B = input['B' if AtoB else 'A']
self.input_A.resize_(input_A.size()).copy_(input_A)
diff --git a/models/models.py b/models/models.py
index 7e790d0..8fea4f4 100644
--- a/models/models.py
+++ b/models/models.py
@@ -4,12 +4,17 @@ def create_model(opt):
print(opt.model)
if opt.model == 'cycle_gan':
from .cycle_gan_model import CycleGANModel
- assert(opt.align_data == False)
+ #assert(opt.align_data == False)
model = CycleGANModel()
- if opt.model == 'pix2pix':
+ elif opt.model == 'pix2pix':
from .pix2pix_model import Pix2PixModel
assert(opt.align_data == True)
model = Pix2PixModel()
+ elif opt.model == 'one_direction_test':
+ from .one_direction_test_model import OneDirectionTestModel
+ model = OneDirectionTestModel()
+ else:
+ raise ValueError("Model [%s] not recognized." % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
diff --git a/models/networks.py b/models/networks.py
index 60e1777..b0f3b11 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
-from pdb import set_trace as st
import numpy as np
###############################################################################
@@ -13,7 +12,7 @@ def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
- elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+ elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
@@ -162,7 +161,7 @@ class ResnetGenerator(nn.Module):
self.model = nn.Sequential(*model)
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
@@ -222,7 +221,7 @@ class UnetGenerator(nn.Module):
self.model = unet_block
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
@@ -323,7 +322,7 @@ class NLayerDiscriminator(nn.Module):
self.model = nn.Sequential(*sequence)
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
diff --git a/models/one_direction_test_model.py b/models/one_direction_test_model.py
new file mode 100644
index 0000000..d4f6442
--- /dev/null
+++ b/models/one_direction_test_model.py
@@ -0,0 +1,51 @@
+from torch.autograd import Variable
+from collections import OrderedDict
+import util.util as util
+from .base_model import BaseModel
+from . import networks
+
+
+class OneDirectionTestModel(BaseModel):
+ def name(self):
+ return 'OneDirectionTestModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+
+ nb = opt.batchSize
+ size = opt.fineSize
+ self.input_A = self.Tensor(nb, opt.input_nc, size, size)
+
+ assert(not self.isTrain)
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
+ opt.ngf, opt.which_model_netG,
+ opt.norm, opt.use_dropout,
+ self.gpu_ids)
+ which_epoch = opt.which_epoch
+ #AtoB = self.opt.which_direction == 'AtoB'
+ #which_network = 'G_A' if AtoB else 'G_B'
+ self.load_network(self.netG_A, 'G', which_epoch)
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG_A)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction == 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def test(self):
+ self.real_A = Variable(self.input_A)
+ self.fake_B = self.netG_A.forward(self.real_A)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
+
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 0e02ebf..34e0bac 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -2,13 +2,11 @@ import numpy as np
import torch
import os
from collections import OrderedDict
-from pdb import set_trace as st
from torch.autograd import Variable
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
-from pdb import set_trace as st
class Pix2PixModel(BaseModel):
def name(self):
@@ -55,7 +53,7 @@ class Pix2PixModel(BaseModel):
print('-----------------------------------------------')
def set_input(self, input):
- AtoB = self.opt.which_direction is 'AtoB'
+ AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
input_B = input['B' if AtoB else 'A']
self.input_A.resize_(input_A.size()).copy_(input_A)
@@ -108,7 +106,6 @@ class Pix2PixModel(BaseModel):
self.loss_G.backward()
def optimize_parameters(self):
- # st()
self.forward()
self.optimizer_D.zero_grad()
diff --git a/options/base_options.py b/options/base_options.py
index 4074746..bce0b9c 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -1,7 +1,7 @@
import argparse
import os
from util import util
-from pdb import set_trace as st
+
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
@@ -35,6 +35,8 @@ class BaseOptions():
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
+ self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+
self.initialized = True
def parse(self):
diff --git a/options/train_options.py b/options/train_options.py
index b241863..4b4eac3 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -10,10 +10,9 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
- self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
- self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
+ self.parser.add_argument('--niter', type=int, default=200, help='# of iter at starting learning rate')
+ self.parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
- self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh
index d5c2960..0d19934 100644
--- a/scripts/test_pix2pix.sh
+++ b/scripts/test_pix2pix.sh
@@ -1 +1 @@
-python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data --use_dropout