 README.md                          |  10
 data/aligned_data_loader.py        |  87
 data/aligned_dataset.py            |  56
 data/base_data_loader.py           |  14
 data/base_dataset.py               |  12
 data/custom_dataset_data_loader.py |  41
 data/data_loader.py                |   9
 data/image_folder.py               |   7
 data/single_dataset.py             |  47
 data/unaligned_data_loader.py      | 100
 data/unaligned_dataset.py          |  56
 models/base_model.py               |   1
 models/cycle_gan_model.py          |  30
 models/models.py                   |   9
 models/test_model.py (renamed from models/one_direction_test_model.py) | 5
 options/base_options.py            |  10
 options/test_options.py            |   1
 options/train_options.py           |   4
 scripts/test_pix2pix.sh            |   2
 scripts/train_pix2pix.sh           |   2
 test.py                            |   1
 21 files changed, 259 insertions(+), 245 deletions(-)
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -44,6 +44,12 @@ In CVPR 2017.
 ## Getting Started
 ### Installation
 - Install PyTorch and dependencies from http://pytorch.org/
+- Install torchvision from source.
+```bash
+git clone https://github.com/pytorch/vision
+cd vision
+python setup.py install
+```
 - Install python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate).
 ```bash
 pip install visdom
@@ -81,13 +87,13 @@ bash ./datasets/download_pix2pix_dataset.sh facades
 - Train a model:
 ```bash
 #!./scripts/train_pix2pix.sh
-python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --use_dropout --no_lsgan
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --use_dropout --no_lsgan
 ```
 - To view training results and loss plots, run `python -m visdom.server` and click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/facades_pix2pix/web/index.html`
 - Test the model (`bash ./scripts/test_pix2pix.sh`):
 ```bash
 #!./scripts/test_pix2pix.sh
-python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned
 ```
 The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`.
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
deleted file mode 100644
index d1d4572..0000000
--- a/data/aligned_data_loader.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import random
-import numpy as np
-import torch.utils.data
-import torchvision.transforms as transforms
-from data.base_data_loader import BaseDataLoader
-from data.image_folder import ImageFolder
-from pdb import set_trace as st
-# pip install future --upgrade
-from builtins import object
-
-
-class PairedData(object):
-    def __init__(self, data_loader, fineSize, max_dataset_size, flip):
-        self.data_loader = data_loader
-        self.fineSize = fineSize
-        self.max_dataset_size = max_dataset_size
-        self.flip = flip
-        # st()
-
-    def __iter__(self):
-        self.data_loader_iter = iter(self.data_loader)
-        self.iter = 0
-        return self
-
-    def __next__(self):
-        self.iter += 1
-        if self.iter > self.max_dataset_size:
-            raise StopIteration
-
-        AB, AB_paths = next(self.data_loader_iter)
-        w_total = AB.size(3)
-        w = int(w_total / 2)
-        h = AB.size(2)
-
-        w_offset = random.randint(0, max(0, w - self.fineSize - 1))
-        h_offset = random.randint(0, max(0, h - self.fineSize - 1))
-        A = AB[:, :, h_offset:h_offset + self.fineSize,
-               w_offset:w_offset + self.fineSize]
-        B = AB[:, :, h_offset:h_offset + self.fineSize,
-               w + w_offset:w + w_offset + self.fineSize]
-
-        if self.flip and random.random() < 0.5:
-            idx = [i for i in range(A.size(3) - 1, -1, -1)]
-            idx = torch.LongTensor(idx)
-            A = A.index_select(3, idx)
-            B = B.index_select(3, idx)
-
-
-
-        return {'A': A, 'A_paths': AB_paths, 'B': B, 'B_paths': AB_paths}
-
-
-class AlignedDataLoader(BaseDataLoader):
-    def initialize(self, opt):
-        BaseDataLoader.initialize(self, opt)
-        self.fineSize = opt.fineSize
-
-        transformations = [
-            # TODO: Scale
-            transforms.Scale(opt.loadSize),
-            transforms.ToTensor(),
-            transforms.Normalize((0.5, 0.5, 0.5),
-                                 (0.5, 0.5, 0.5))]
-        transform = transforms.Compose(transformations)
-
-        # Dataset A
-        dataset = ImageFolder(root=opt.dataroot + '/' + opt.phase,
-                              transform=transform, return_paths=True)
-        data_loader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self.opt.batchSize,
-            shuffle=not self.opt.serial_batches,
-            num_workers=int(self.opt.nThreads))
-
-        self.dataset = dataset
-
-        flip = opt.isTrain and not opt.no_flip
-        self.paired_data = PairedData(data_loader, opt.fineSize,
-                                      opt.max_dataset_size, flip)
-
-    def name(self):
-        return 'AlignedDataLoader'
-
-    def load_data(self):
-        return self.paired_data
-
-    def __len__(self):
-        return min(len(self.dataset), self.opt.max_dataset_size)
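The flip in the deleted loader (and in the new `AlignedDataset` below) predates `torch.flip`, so it reverses the width index and applies `index_select`. A minimal sketch of the same trick on a single C x H x W tensor; the `hflip` name is hypothetical:

```python
import torch

def hflip(img):
    # reverse the width dimension (dim 2 of a C x H x W tensor);
    # the deleted loader did the same on dim 3 of a batched N x C x H x W
    idx = torch.LongTensor(list(range(img.size(2) - 1, -1, -1)))
    return img.index_select(2, idx)
```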
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
new file mode 100644
index 0000000..0f45c40
--- /dev/null
+++ b/data/aligned_dataset.py
@@ -0,0 +1,56 @@
+import os.path
+import random
+import torchvision.transforms as transforms
+import torch
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class AlignedDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
+
+        self.AB_paths = sorted(make_dataset(self.dir_AB))
+
+        assert(opt.resize_or_crop == 'resize_and_crop')
+
+        transform_list = [transforms.ToTensor(),
+                          transforms.Normalize((0.5, 0.5, 0.5),
+                                               (0.5, 0.5, 0.5))]
+
+        self.transform = transforms.Compose(transform_list)
+
+    def __getitem__(self, index):
+        AB_path = self.AB_paths[index]
+        AB = Image.open(AB_path).convert('RGB')
+        AB = AB.resize((self.opt.loadSize * 2, self.opt.loadSize), Image.BICUBIC)
+        AB = self.transform(AB)
+
+        w_total = AB.size(2)
+        w = int(w_total / 2)
+        h = AB.size(1)
+        w_offset = random.randint(0, max(0, w - self.opt.fineSize - 1))
+        h_offset = random.randint(0, max(0, h - self.opt.fineSize - 1))
+
+        A = AB[:, h_offset:h_offset + self.opt.fineSize,
+               w_offset:w_offset + self.opt.fineSize]
+        B = AB[:, h_offset:h_offset + self.opt.fineSize,
+               w + w_offset:w + w_offset + self.opt.fineSize]
+
+        if (not self.opt.no_flip) and random.random() < 0.5:
+            idx = [i for i in range(A.size(2) - 1, -1, -1)]
+            idx = torch.LongTensor(idx)
+            A = A.index_select(2, idx)
+            B = B.index_select(2, idx)
+
+        return {'A': A, 'B': B,
+                'A_paths': AB_path, 'B_paths': AB_path}
+
+    def __len__(self):
+        return len(self.AB_paths)
+
+    def name(self):
+        return 'AlignedDataset'
diff --git a/data/base_data_loader.py b/data/base_data_loader.py
deleted file mode 100644
index 0e1deb5..0000000
--- a/data/base_data_loader.py
+++ /dev/null
@@ -1,14 +0,0 @@
-
-class BaseDataLoader():
-    def __init__(self):
-        pass
-
-    def initialize(self, opt):
-        self.opt = opt
-        pass
-
-    def load_data():
-        return None
-
-
-
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100644
index 0000000..49b9d98
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,12 @@
+import torch.utils.data as data
+
+
+class BaseDataset(data.Dataset):
+    def __init__(self):
+        super(BaseDataset, self).__init__()
+
+    def name(self):
+        return 'BaseDataset'
+
+    def initialize(self, opt):
+        pass
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
new file mode 100644
index 0000000..60180e0
--- /dev/null
+++ b/data/custom_dataset_data_loader.py
@@ -0,0 +1,41 @@
+import torch.utils.data
+from data.base_data_loader import BaseDataLoader
+
+
+def CreateDataset(opt):
+    dataset = None
+    if opt.dataset_mode == 'aligned':
+        from data.aligned_dataset import AlignedDataset
+        dataset = AlignedDataset()
+    elif opt.dataset_mode == 'unaligned':
+        from data.unaligned_dataset import UnalignedDataset
+        dataset = UnalignedDataset()
+    elif opt.dataset_mode == 'single':
+        from data.single_dataset import SingleDataset
+        dataset = SingleDataset()
+    else:
+        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
+
+    print("dataset [%s] was created" % (dataset.name()))
+    dataset.initialize(opt)
+    return dataset
+
+
+class CustomDatasetDataLoader(BaseDataLoader):
+    def name(self):
+        return 'CustomDatasetDataLoader'
+
+    def initialize(self, opt):
+        BaseDataLoader.initialize(self, opt)
+        self.dataset = CreateDataset(opt)
+        self.dataloader = torch.utils.data.DataLoader(
+            self.dataset,
+            batch_size=opt.batchSize,
+            shuffle=not opt.serial_batches,
+            num_workers=int(opt.nThreads))
+
+    def load_data(self):
+        return self.dataloader
+
+    def __len__(self):
+        return min(len(self.dataset), self.opt.max_dataset_size)
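The refactored path is consumed exactly like the old one. A minimal usage sketch, assuming `opt` is a parsed options namespace carrying the fields referenced above:

```python
from data.data_loader import CreateDataLoader

data_loader = CreateDataLoader(opt)  # dispatches on opt.dataset_mode
dataset = data_loader.load_data()    # the underlying torch DataLoader
for i, data in enumerate(dataset):
    # 'aligned' and 'unaligned' modes yield dicts with 'A', 'B',
    # 'A_paths', 'B_paths'; 'single' mode yields 'A' and 'A_paths' only
    real_A = data['A']
```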
diff --git a/data/data_loader.py b/data/data_loader.py
index 69035ea..2a4433a 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -1,12 +1,7 @@
 def CreateDataLoader(opt):
-    data_loader = None
-    if opt.align_data > 0:
-        from data.aligned_data_loader import AlignedDataLoader
-        data_loader = AlignedDataLoader()
-    else:
-        from data.unaligned_data_loader import UnalignedDataLoader
-        data_loader = UnalignedDataLoader()
+    from data.custom_dataset_data_loader import CustomDatasetDataLoader
+    data_loader = CustomDatasetDataLoader()
     print(data_loader.name())
     data_loader.initialize(opt)
     return data_loader
diff --git a/data/image_folder.py b/data/image_folder.py
index 44e15cb..898200b 100644
--- a/data/image_folder.py
+++ b/data/image_folder.py
@@ -1,9 +1,9 @@
-################################################################################
+###############################################################################
 # Code from
 # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
 # Modified the original code so that it also loads images from the current
 # directory as well as the subdirectories
-################################################################################
+###############################################################################
 
 import torch.utils.data as data
@@ -45,7 +45,8 @@ class ImageFolder(data.Dataset):
         imgs = make_dataset(root)
         if len(imgs) == 0:
             raise(RuntimeError("Found 0 images in: " + root + "\n"
-                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+                               "Supported image extensions are: " +
+                               ",".join(IMG_EXTENSIONS)))
 
         self.root = root
         self.imgs = imgs
diff --git a/data/single_dataset.py b/data/single_dataset.py
new file mode 100644
index 0000000..106bea3
--- /dev/null
+++ b/data/single_dataset.py
@@ -0,0 +1,47 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class SingleDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot)
+
+        self.A_paths = make_dataset(self.dir_A)
+
+        self.A_paths = sorted(self.A_paths)
+
+        transform_list = []
+        if opt.resize_or_crop == 'resize_and_crop':
+            transform_list.append(transforms.Scale(opt.loadSize))
+
+        if opt.isTrain and not opt.no_flip:
+            transform_list.append(transforms.RandomHorizontalFlip())
+
+        if opt.resize_or_crop != 'no_resize':
+            transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+        transform_list += [transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+
+        self.transform = transforms.Compose(transform_list)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index]
+
+        A_img = Image.open(A_path).convert('RGB')
+
+        A_img = self.transform(A_img)
+
+        return {'A': A_img, 'A_paths': A_path}
+
+    def __len__(self):
+        return len(self.A_paths)
+
+    def name(self):
+        return 'SingleImageDataset'
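All three datasets assemble preprocessing from the same torchvision building blocks. A condensed sketch of the `resize_and_crop` pipeline built above, using `loadSize=286` and `fineSize=256` as illustrative values (the repository's usual defaults, assumed here) and an illustrative file path:

```python
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Scale(286),            # torchvision's old name for Resize
    transforms.RandomCrop(256),
    transforms.ToTensor(),            # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5)),  # rescale channels to [-1, 1]
])

img = transform(Image.open('example.jpg').convert('RGB'))
```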
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
deleted file mode 100644
index bd0ea75..0000000
--- a/data/unaligned_data_loader.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import random
-import torch.utils.data
-import torchvision.transforms as transforms
-from data.base_data_loader import BaseDataLoader
-from data.image_folder import ImageFolder
-# pip install future --upgrade
-from builtins import object
-from pdb import set_trace as st
-
-
-class PairedData(object):
-    def __init__(self, data_loader_A, data_loader_B, max_dataset_size, flip):
-        self.data_loader_A = data_loader_A
-        self.data_loader_B = data_loader_B
-        self.stop_A = False
-        self.stop_B = False
-        self.max_dataset_size = max_dataset_size
-        self.flip = flip
-
-    def __iter__(self):
-        self.stop_A = False
-        self.stop_B = False
-        self.data_loader_A_iter = iter(self.data_loader_A)
-        self.data_loader_B_iter = iter(self.data_loader_B)
-        self.iter = 0
-        return self
-
-    def __next__(self):
-        A, A_paths = None, None
-        B, B_paths = None, None
-        try:
-            A, A_paths = next(self.data_loader_A_iter)
-        except StopIteration:
-            if A is None or A_paths is None:
-                self.stop_A = True
-                self.data_loader_A_iter = iter(self.data_loader_A)
-                A, A_paths = next(self.data_loader_A_iter)
-
-        try:
-            B, B_paths = next(self.data_loader_B_iter)
-        except StopIteration:
-            if B is None or B_paths is None:
-                self.stop_B = True
-                self.data_loader_B_iter = iter(self.data_loader_B)
-                B, B_paths = next(self.data_loader_B_iter)
-
-        if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
-            self.stop_A = False
-            self.stop_B = False
-            raise StopIteration()
-        else:
-            self.iter += 1
-            if self.flip and random.random() < 0.5:
-                idx = [i for i in range(A.size(3) - 1, -1, -1)]
-                idx = torch.LongTensor(idx)
-                A = A.index_select(3, idx)
-                B = B.index_select(3, idx)
-            return {'A': A, 'A_paths': A_paths,
-                    'B': B, 'B_paths': B_paths}
-
-
-class UnalignedDataLoader(BaseDataLoader):
-    def initialize(self, opt):
-        BaseDataLoader.initialize(self, opt)
-        transformations = [transforms.Scale(opt.loadSize),
-                           transforms.RandomCrop(opt.fineSize),
-                           transforms.ToTensor(),
-                           transforms.Normalize((0.5, 0.5, 0.5),
-                                                (0.5, 0.5, 0.5))]
-        transform = transforms.Compose(transformations)
-
-        # Dataset A
-        dataset_A = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'A',
-                                transform=transform, return_paths=True)
-        data_loader_A = torch.utils.data.DataLoader(
-            dataset_A,
-            batch_size=self.opt.batchSize,
-            shuffle=not self.opt.serial_batches,
-            num_workers=int(self.opt.nThreads))
-
-        # Dataset B
-        dataset_B = ImageFolder(root=opt.dataroot + '/' + opt.phase + 'B',
-                                transform=transform, return_paths=True)
-        data_loader_B = torch.utils.data.DataLoader(
-            dataset_B,
-            batch_size=self.opt.batchSize,
-            shuffle=not self.opt.serial_batches,
-            num_workers=int(self.opt.nThreads))
-        self.dataset_A = dataset_A
-        self.dataset_B = dataset_B
-        flip = opt.isTrain and not opt.no_flip
-        self.paired_data = PairedData(data_loader_A, data_loader_B,
-                                      self.opt.max_dataset_size, flip)
-
-    def name(self):
-        return 'UnalignedDataLoader'
-
-    def load_data(self):
-        return self.paired_data
-
-    def __len__(self):
-        return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)
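For reference, the deleted `PairedData` restarts whichever iterator runs out first, so one epoch covers the longer of the two image sets. A condensed sketch of that pattern, stripped of the flipping and max-size bookkeeping:

```python
def paired(loader_A, loader_B):
    """Yield (A, B) pairs, recycling the shorter loader until both wrap."""
    iter_A, iter_B = iter(loader_A), iter(loader_B)
    stop_A = stop_B = False
    while True:
        try:
            A = next(iter_A)
        except StopIteration:
            stop_A, iter_A = True, iter(loader_A)
            A = next(iter_A)
        try:
            B = next(iter_B)
        except StopIteration:
            stop_B, iter_B = True, iter(loader_B)
            B = next(iter_B)
        if stop_A and stop_B:  # both exhausted at least once: epoch over
            return
        yield A, B
```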
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
new file mode 100644
index 0000000..1f75b23
--- /dev/null
+++ b/data/unaligned_dataset.py
@@ -0,0 +1,56 @@
+import os.path
+import torchvision.transforms as transforms
+from data.base_dataset import BaseDataset
+from data.image_folder import make_dataset
+from PIL import Image
+import PIL
+from pdb import set_trace as st
+
+
+class UnalignedDataset(BaseDataset):
+    def initialize(self, opt):
+        self.opt = opt
+        self.root = opt.dataroot
+        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
+        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
+
+        self.A_paths = make_dataset(self.dir_A)
+        self.B_paths = make_dataset(self.dir_B)
+
+        self.A_paths = sorted(self.A_paths)
+        self.B_paths = sorted(self.B_paths)
+
+        transform_list = []
+        if opt.resize_or_crop == 'resize_and_crop':
+            osize = [opt.loadSize, opt.loadSize]
+            transform_list.append(transforms.Scale(osize, Image.BICUBIC))
+
+        if opt.isTrain and not opt.no_flip:
+            transform_list.append(transforms.RandomHorizontalFlip())
+
+        if opt.resize_or_crop != 'no_resize':
+            transform_list.append(transforms.RandomCrop(opt.fineSize))
+
+        transform_list += [transforms.ToTensor(),
+                           transforms.Normalize((0.5, 0.5, 0.5),
+                                                (0.5, 0.5, 0.5))]
+        self.transform = transforms.Compose(transform_list)
+
+    def __getitem__(self, index):
+        A_path = self.A_paths[index]
+        B_path = self.B_paths[index]
+
+        A_img = Image.open(A_path).convert('RGB')
+        B_img = Image.open(B_path).convert('RGB')
+
+        A_img = self.transform(A_img)
+        B_img = self.transform(B_img)
+
+        return {'A': A_img, 'B': B_img,
+                'A_paths': A_path, 'B_paths': B_path}
+
+    def __len__(self):
+        return min(len(self.A_paths), len(self.B_paths))
+
+    def name(self):
+        return 'UnalignedDataset'
diff --git a/models/base_model.py b/models/base_model.py
index 9b92bb4..36ceb43 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -1,6 +1,7 @@
 import os
 import torch
 
+
 class BaseModel():
     def name(self):
         return 'BaseModel'
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index f8c4f9f..6fbb19f 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -10,6 +10,7 @@ from .base_model import BaseModel
 from . import networks
 import sys
 
+
 class CycleGANModel(BaseModel):
     def name(self):
         return 'CycleGANModel'
@@ -27,18 +28,18 @@ class CycleGANModel(BaseModel):
         # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
 
         self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
-                            opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
+                                        opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
         self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
-                            opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
+                                        opt.ngf, opt.which_model_netG, opt.norm, opt.use_dropout, self.gpu_ids)
 
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
-                            opt.which_model_netD,
-                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+                                            opt.which_model_netD,
+                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
             self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
-                            opt.which_model_netD,
-                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+                                            opt.which_model_netD,
+                                            opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             which_epoch = opt.which_epoch
             self.load_network(self.netG_A, 'G_A', which_epoch)
@@ -58,10 +59,8 @@ class CycleGANModel(BaseModel):
             # initialize optimizers
             self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
-            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
-                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
-            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
-                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
 
             print('---------- Networks initialized -------------')
             networks.print_network(self.netG_A)
@@ -89,9 +88,9 @@ class CycleGANModel(BaseModel):
         self.real_B = Variable(self.input_B, volatile=True)
         self.fake_A = self.netG_B.forward(self.real_B)
-        self.rec_B  = self.netG_A.forward(self.fake_A)
+        self.rec_B = self.netG_A.forward(self.fake_A)
 
-    #get image paths
+    # get image paths
     def get_image_paths(self):
         return self.image_paths
@@ -114,7 +113,7 @@ class CycleGANModel(BaseModel):
 
     def backward_D_B(self):
         fake_A = self.fake_A_pool.query(self.fake_A)
-        self.loss_D_B  = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
 
     def backward_G(self):
         lambda_idt = self.opt.identity
@@ -167,7 +166,6 @@ class CycleGANModel(BaseModel):
         self.backward_D_B()
         self.optimizer_D_B.step()
 
-
     def get_current_errors(self):
         D_A = self.loss_D_A.data[0]
         G_A = self.loss_G_A.data[0]
@@ -187,10 +185,10 @@ class CycleGANModel(BaseModel):
     def get_current_visuals(self):
         real_A = util.tensor2im(self.real_A.data)
         fake_B = util.tensor2im(self.fake_B.data)
-        rec_A  = util.tensor2im(self.rec_A.data)
+        rec_A = util.tensor2im(self.rec_A.data)
         real_B = util.tensor2im(self.real_B.data)
         fake_A = util.tensor2im(self.fake_A.data)
-        rec_B  = util.tensor2im(self.rec_B.data)
+        rec_B = util.tensor2im(self.rec_B.data)
         if self.opt.identity > 0.0:
             idt_A = util.tensor2im(self.idt_A.data)
             idt_B = util.tensor2im(self.idt_B.data)
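One detail in the optimizer cleanup above is worth spelling out: both generators share a single Adam instance by chaining their parameter iterators. A minimal sketch with stand-in modules; the lr and beta1 values are the repository's usual defaults, assumed here:

```python
import itertools

import torch
import torch.nn as nn

netG_A = nn.Linear(8, 8)  # stand-ins for the two generators
netG_B = nn.Linear(8, 8)

# one optimizer steps both networks at once; chain() concatenates
# the two parameter iterators into a single parameter group
optimizer_G = torch.optim.Adam(
    itertools.chain(netG_A.parameters(), netG_B.parameters()),
    lr=0.0002, betas=(0.5, 0.999))
```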
diff --git a/models/models.py b/models/models.py
index 8fea4f4..efcd898 100644
--- a/models/models.py
+++ b/models/models.py
@@ -3,15 +3,16 @@ def create_model(opt):
     model = None
     print(opt.model)
     if opt.model == 'cycle_gan':
+        assert(opt.dataset_mode == 'unaligned')
         from .cycle_gan_model import CycleGANModel
-        #assert(opt.align_data == False)
         model = CycleGANModel()
     elif opt.model == 'pix2pix':
+        assert(opt.dataset_mode == 'aligned')
         from .pix2pix_model import Pix2PixModel
-        assert(opt.align_data == True)
         model = Pix2PixModel()
-    elif opt.model == 'one_direction_test':
-        from .one_direction_test_model import OneDirectionTestModel
-        model = OneDirectionTestModel()
+    elif opt.model == 'test':
+        assert(opt.dataset_mode == 'single')
+        from .test_model import TestModel
+        model = TestModel()
     else:
         raise ValueError("Model [%s] not recognized." % opt.model)
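A brief usage sketch for the factory above: the new asserts tie each model to the dataset mode it expects, so a mismatched flag combination now fails fast. Here `opt`, `dataset`, and the `set_input`/`optimize_parameters` calls are assumptions drawn from the surrounding codebase, not shown in this diff:

```python
from models.models import create_model

# e.g. opt.model == 'cycle_gan' now requires opt.dataset_mode == 'unaligned'
model = create_model(opt)

for data in dataset:
    model.set_input(data)          # assumed model API: feed one batch dict
    model.optimize_parameters()    # assumed model API: one training step
```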
diff --git a/models/one_direction_test_model.py b/models/test_model.py
index d4f6442..a356263 100644
--- a/models/one_direction_test_model.py
+++ b/models/test_model.py
@@ -5,9 +5,9 @@ from .base_model import BaseModel
 from . import networks
 
 
-class OneDirectionTestModel(BaseModel):
+class TestModel(BaseModel):
     def name(self):
-        return 'OneDirectionTestModel'
+        return 'TestModel'
 
     def initialize(self, opt):
         BaseModel.initialize(self, opt)
@@ -48,4 +48,3 @@ class TestModel(BaseModel):
         real_A = util.tensor2im(self.real_A.data)
         fake_B = util.tensor2im(self.fake_B.data)
         return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
-
diff --git a/options/base_options.py b/options/base_options.py
index 619ca60..44100cf 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -2,6 +2,7 @@ import argparse
 import os
 from util import util
 
+
 class BaseOptions():
     def __init__(self):
         self.parser = argparse.ArgumentParser()
@@ -21,10 +22,9 @@ class BaseOptions():
         self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2')
         self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
-        self.parser.add_argument('--align_data', action='store_true',
-                                 help='if True, the datasets are loaded from "test" and "train" directories and the data pairs are aligned')
+        self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
         self.parser.add_argument('--model', type=str, default='cycle_gan',
-                                 help='chooses which model to use. cycle_gan, one_direction_test, pix2pix, ...')
+                                 help='chooses which model to use. cycle_gan, pix2pix, test')
         self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
         self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
         self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
@@ -35,6 +35,8 @@ class BaseOptions():
         self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
         self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
         self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width]')
+        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
 
         self.initialized = True
@@ -59,7 +61,7 @@ class BaseOptions():
         print('-------------- End ----------------')
 
         # save to the disk
-        expr_dir =  os.path.join(self.opt.checkpoints_dir, self.opt.name)
+        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
         util.mkdirs(expr_dir)
         file_name = os.path.join(expr_dir, 'opt.txt')
         with open(file_name, 'wt') as opt_file:
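The options classes are thin wrappers over argparse. A hedged sketch of the typical entry point, assuming `BaseOptions` exposes a `parse()` method (not shown in this diff) that returns the populated namespace and writes it to `opt.txt` as in the hunk above:

```python
from options.train_options import TrainOptions

opt = TrainOptions().parse()   # parse() is assumed from the rest of the repo
print(opt.dataset_mode)        # 'unaligned' unless --dataset_mode is passed
print(opt.resize_or_crop)      # 'resize_and_crop' by default
print(opt.no_flip)             # False unless --no_flip is passed
```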
diff --git a/options/test_options.py b/options/test_options.py
index c4ecff6..6b79860 100644
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -1,5 +1,6 @@
 from .base_options import BaseOptions
 
+
 class TestOptions(BaseOptions):
     def initialize(self):
         BaseOptions.initialize(self)
diff --git a/options/train_options.py b/options/train_options.py
index a1d347f..345f619 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -1,5 +1,6 @@
 from .base_options import BaseOptions
 
+
class TrainOptions(BaseOptions):
     def initialize(self):
         BaseOptions.initialize(self)
@@ -19,7 +20,4 @@ class TrainOptions(BaseOptions):
         self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
         self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
-        self.parser.add_argument('--no_flip' , action='store_true', help='if specified, do not flip the images for data argumentation')
-
-        # NOT-IMPLEMENTED self.parser.add_argument('--preprocessing', type=str, default='resize_and_crop', help='resizing/cropping strategy')
 
         self.isTrain = True
diff --git a/scripts/test_pix2pix.sh b/scripts/test_pix2pix.sh
index 0d19934..b821878 100644
--- a/scripts/test_pix2pix.sh
+++ b/scripts/test_pix2pix.sh
@@ -1 +1 @@
-python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data --use_dropout
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --use_dropout
diff --git a/scripts/train_pix2pix.sh b/scripts/train_pix2pix.sh
index f14e7da..bf45c84 100644
--- a/scripts/train_pix2pix.sh
+++ b/scripts/train_pix2pix.sh
@@ -1 +1 @@
-python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --align_data --no_lsgan --use_dropout
+python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --lambda_A 100 --dataset_mode aligned --no_lsgan --use_dropout
diff --git a/test.py b/test.py
--- a/test.py
+++ b/test.py
@@ -12,6 +12,7 @@ from util import html
 opt.nThreads = 1   # test code only supports nThreads = 1
 opt.batchSize = 1  # test code only supports batchSize = 1
 opt.serial_batches = True  # no shuffle
+opt.no_flip = True  # no flip
 data_loader = CreateDataLoader(opt)
 dataset = data_loader.load_data()
