diff options
| -rw-r--r-- | data/aligned_data_loader.py | 17 | ||||
| -rw-r--r-- | data/unaligned_data_loader.py | 16 | ||||
| -rw-r--r-- | models/base_model.py | 1 | ||||
| -rw-r--r-- | models/cycle_gan_model.py | 3 | ||||
| -rw-r--r-- | models/models.py | 9 | ||||
| -rw-r--r-- | models/networks.py | 9 | ||||
| -rw-r--r-- | models/one_direction_test_model.py | 51 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 5 | ||||
| -rw-r--r-- | options/base_options.py | 4 | ||||
| -rw-r--r-- | options/train_options.py | 5 | ||||
| -rw-r--r-- | scripts/test_pix2pix.sh | 2 |
11 files changed, 91 insertions, 31 deletions
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py index bea3531..a1efde8 100644 --- a/data/aligned_data_loader.py +++ b/data/aligned_data_loader.py @@ -4,19 +4,26 @@ import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder from pdb import set_trace as st +# pip install future --upgrade from builtins import object class PairedData(object): - def __init__(self, data_loader, fineSize): + def __init__(self, data_loader, fineSize, max_dataset_size): self.data_loader = data_loader self.fineSize = fineSize + self.max_dataset_size = max_dataset_size # st() def __iter__(self): self.data_loader_iter = iter(self.data_loader) + self.iter = 0 return self def __next__(self): + self.iter += 1 + if self.iter > self.max_dataset_size: + raise StopIteration + AB, AB_paths = next(self.data_loader_iter) w_total = AB.size(3) w = int(w_total / 2) @@ -24,7 +31,6 @@ class PairedData(object): w_offset = random.randint(0, max(0, w - self.fineSize - 1)) h_offset = random.randint(0, max(0, h - self.fineSize - 1)) - A = AB[:, :, h_offset:h_offset + self.fineSize, w_offset:w_offset + self.fineSize] B = AB[:, :, h_offset:h_offset + self.fineSize, @@ -39,8 +45,7 @@ class AlignedDataLoader(BaseDataLoader): self.fineSize = opt.fineSize transform = transforms.Compose([ # TODO: Scale - #transforms.Scale((opt.loadSize * 2, opt.loadSize)), - #transforms.CenterCrop(opt.fineSize), + transforms.Scale(opt.loadSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) @@ -55,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset = dataset - self.paired_data = PairedData(data_loader, opt.fineSize) + self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size) def name(self): return 'AlignedDataLoader' @@ -64,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return len(self.dataset) + return min(len(self.dataset), self.opt.max_dataset_size) diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py index 4a06510..77f9274 100644 --- a/data/unaligned_data_loader.py +++ b/data/unaligned_data_loader.py @@ -2,21 +2,24 @@ import torch.utils.data import torchvision.transforms as transforms from data.base_data_loader import BaseDataLoader from data.image_folder import ImageFolder +# pip install future --upgrade from builtins import object from pdb import set_trace as st class PairedData(object): - def __init__(self, data_loader_A, data_loader_B): + def __init__(self, data_loader_A, data_loader_B, max_dataset_size): self.data_loader_A = data_loader_A self.data_loader_B = data_loader_B self.stop_A = False self.stop_B = False + self.max_dataset_size = max_dataset_size def __iter__(self): self.stop_A = False self.stop_B = False self.data_loader_A_iter = iter(self.data_loader_A) self.data_loader_B_iter = iter(self.data_loader_B) + self.iter = 0 return self def __next__(self): @@ -29,20 +32,21 @@ class PairedData(object): self.stop_A = True self.data_loader_A_iter = iter(self.data_loader_A) A, A_paths = next(self.data_loader_A_iter) + try: B, B_paths = next(self.data_loader_B_iter) - except StopIteration: if B is None or B_paths is None: self.stop_B = True self.data_loader_B_iter = iter(self.data_loader_B) B, B_paths = next(self.data_loader_B_iter) - if self.stop_A and self.stop_B: + if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size: self.stop_A = False self.stop_B = False raise StopIteration() else: + self.iter += 1 return {'A': A, 'A_paths': A_paths, 'B': B, 'B_paths': B_paths} @@ -51,7 +55,7 @@ class UnalignedDataLoader(BaseDataLoader): BaseDataLoader.initialize(self, opt) transform = transforms.Compose([ transforms.Scale(opt.loadSize), - transforms.CenterCrop(opt.fineSize), + transforms.RandomCrop(opt.fineSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) @@ -75,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader): num_workers=int(self.opt.nThreads)) self.dataset_A = dataset_A self.dataset_B = dataset_B - self.paired_data = PairedData(data_loader_A, data_loader_B) + self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size) def name(self): return 'UnalignedDataLoader' @@ -84,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader): return self.paired_data def __len__(self): - return max(len(self.dataset_A), len(self.dataset_B)) + return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size) diff --git a/models/base_model.py b/models/base_model.py index ce18635..9b92bb4 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -1,6 +1,5 @@ import os import torch -from pdb import set_trace as st class BaseModel(): def name(self): diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index d361e47..451002d 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -2,7 +2,6 @@ import numpy as np import torch import os from collections import OrderedDict -from pdb import set_trace as st from torch.autograd import Variable import itertools import util.util as util @@ -72,7 +71,7 @@ class CycleGANModel(BaseModel): print('-----------------------------------------------') def set_input(self, input): - AtoB = self.opt.which_direction is 'AtoB' + AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) diff --git a/models/models.py b/models/models.py index 7e790d0..8fea4f4 100644 --- a/models/models.py +++ b/models/models.py @@ -4,12 +4,17 @@ def create_model(opt): print(opt.model) if opt.model == 'cycle_gan': from .cycle_gan_model import CycleGANModel - assert(opt.align_data == False) + #assert(opt.align_data == False) model = CycleGANModel() - if opt.model == 'pix2pix': + elif opt.model == 'pix2pix': from .pix2pix_model import Pix2PixModel assert(opt.align_data == True) model = Pix2PixModel() + elif opt.model == 'one_direction_test': + from .one_direction_test_model import OneDirectionTestModel + model = OneDirectionTestModel() + else: + raise ValueError("Model [%s] not recognized." % opt.model) model.initialize(opt) print("model [%s] was created" % (model.name())) return model diff --git a/models/networks.py b/models/networks.py index 60e1777..b0f3b11 100644 --- a/models/networks.py +++ b/models/networks.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn from torch.autograd import Variable -from pdb import set_trace as st import numpy as np ############################################################################### @@ -13,7 +12,7 @@ def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1: + elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) @@ -162,7 +161,7 @@ class ResnetGenerator(nn.Module): self.model = nn.Sequential(*model) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -222,7 +221,7 @@ class UnetGenerator(nn.Module): self.model = unet_block def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) @@ -323,7 +322,7 @@ class NLayerDiscriminator(nn.Module): self.model = nn.Sequential(*sequence) def forward(self, input): - if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor): return nn.parallel.data_parallel(self.model, input, self.gpu_ids) else: return self.model(input) diff --git a/models/one_direction_test_model.py b/models/one_direction_test_model.py new file mode 100644 index 0000000..d4f6442 --- /dev/null +++ b/models/one_direction_test_model.py @@ -0,0 +1,51 @@ +from torch.autograd import Variable +from collections import OrderedDict +import util.util as util +from .base_model import BaseModel +from . import networks + + +class OneDirectionTestModel(BaseModel): + def name(self): + return 'OneDirectionTestModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + + assert(not self.isTrain) + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, + opt.norm, opt.use_dropout, + self.gpu_ids) + which_epoch = opt.which_epoch + #AtoB = self.opt.which_direction == 'AtoB' + #which_network = 'G_A' if AtoB else 'G_B' + self.load_network(self.netG_A, 'G', which_epoch) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction == 'AtoB' + input_A = input['A' if AtoB else 'B'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def test(self): + self.real_A = Variable(self.input_A) + self.fake_B = self.netG_A.forward(self.real_A) + + #get image paths + def get_image_paths(self): + return self.image_paths + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) + diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 0e02ebf..34e0bac 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -2,13 +2,11 @@ import numpy as np import torch import os from collections import OrderedDict -from pdb import set_trace as st from torch.autograd import Variable import util.util as util from util.image_pool import ImagePool from .base_model import BaseModel from . import networks -from pdb import set_trace as st class Pix2PixModel(BaseModel): def name(self): @@ -55,7 +53,7 @@ class Pix2PixModel(BaseModel): print('-----------------------------------------------') def set_input(self, input): - AtoB = self.opt.which_direction is 'AtoB' + AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] self.input_A.resize_(input_A.size()).copy_(input_A) @@ -108,7 +106,6 @@ class Pix2PixModel(BaseModel): self.loss_G.backward() def optimize_parameters(self): - # st() self.forward() self.optimizer_D.zero_grad() diff --git a/options/base_options.py b/options/base_options.py index 4074746..bce0b9c 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -1,7 +1,7 @@ import argparse import os from util import util -from pdb import set_trace as st + class BaseOptions(): def __init__(self): self.parser = argparse.ArgumentParser() @@ -35,6 +35,8 @@ class BaseOptions(): self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1') self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') + self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + self.initialized = True def parse(self): diff --git a/options/train_options.py b/options/train_options.py index b241863..4b4eac3 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -10,10 +10,9 @@ class TrainOptions(BaseOptions): self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') - self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') - self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') + self.parser.add_argument('--niter', type=int, default=200, help='# of iter at starting learning rate') + self.parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') - self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.') self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh index d5c2960..0d19934 100644 --- a/scripts/test_pix2pix.sh +++ b/scripts/test_pix2pix.sh @@ -1 +1 @@ -python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data +python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data --use_dropout |
