summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--  README.md                            2
-rw-r--r--  data/aligned_data_loader.py         10
-rw-r--r--  data/unaligned_data_loader.py       10
-rw-r--r--  models/models.py                     9
-rw-r--r--  models/one_direction_test_model.py  51
-rw-r--r--  options/base_options.py              2
-rw-r--r--  options/train_options.py             1
7 files changed, 71 insertions, 14 deletions
diff --git a/README.md b/README.md
index abccd2f..4bb1595 100644
--- a/README.md
+++ b/README.md
@@ -87,7 +87,7 @@ python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix
- Test the model (`bash ./scripts/test_pix2pix.sh`):
```bash
#!./scripts/test_pix2pix.sh
-python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data --use_dropout
+python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --which_model_netG unet_256 --which_direction BtoA --align_data
```
The test results will be saved to a html file here: `./results/facades_pix2pix/latest_val/index.html`.
diff --git a/data/aligned_data_loader.py b/data/aligned_data_loader.py
index b7a228b..a1efde8 100644
--- a/data/aligned_data_loader.py
+++ b/data/aligned_data_loader.py
@@ -8,10 +8,10 @@ from pdb import set_trace as st
from builtins import object
class PairedData(object):
- def __init__(self, data_loader, fineSize, ntrain):
+ def __init__(self, data_loader, fineSize, max_dataset_size):
self.data_loader = data_loader
self.fineSize = fineSize
- self.ntrain = ntrain
+ self.max_dataset_size = max_dataset_size
# st()
def __iter__(self):
@@ -21,7 +21,7 @@ class PairedData(object):
def __next__(self):
self.iter += 1
- if self.iter > self.ntrain:
+ if self.iter > self.max_dataset_size:
raise StopIteration
AB, AB_paths = next(self.data_loader_iter)
@@ -60,7 +60,7 @@ class AlignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset = dataset
- self.paired_data = PairedData(data_loader, opt.fineSize, opt.ntrain)
+ self.paired_data = PairedData(data_loader, opt.fineSize, opt.max_dataset_size)
def name(self):
return 'AlignedDataLoader'
@@ -69,4 +69,4 @@ class AlignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return min(len(self.dataset), self.opt.ntrain)
+ return min(len(self.dataset), self.opt.max_dataset_size)
diff --git a/data/unaligned_data_loader.py b/data/unaligned_data_loader.py
index 4a29c7c..77f9274 100644
--- a/data/unaligned_data_loader.py
+++ b/data/unaligned_data_loader.py
@@ -7,12 +7,12 @@ from builtins import object
from pdb import set_trace as st
class PairedData(object):
- def __init__(self, data_loader_A, data_loader_B, ntrain):
+ def __init__(self, data_loader_A, data_loader_B, max_dataset_size):
self.data_loader_A = data_loader_A
self.data_loader_B = data_loader_B
self.stop_A = False
self.stop_B = False
- self.ntrain = ntrain
+ self.max_dataset_size = max_dataset_size
def __iter__(self):
self.stop_A = False
@@ -41,7 +41,7 @@ class PairedData(object):
self.data_loader_B_iter = iter(self.data_loader_B)
B, B_paths = next(self.data_loader_B_iter)
- if (self.stop_A and self.stop_B) or self.iter > self.ntrain:
+ if (self.stop_A and self.stop_B) or self.iter > self.max_dataset_size:
self.stop_A = False
self.stop_B = False
raise StopIteration()
@@ -79,7 +79,7 @@ class UnalignedDataLoader(BaseDataLoader):
num_workers=int(self.opt.nThreads))
self.dataset_A = dataset_A
self.dataset_B = dataset_B
- self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.ntrain)
+ self.paired_data = PairedData(data_loader_A, data_loader_B, self.opt.max_dataset_size)
def name(self):
return 'UnalignedDataLoader'
@@ -88,4 +88,4 @@ class UnalignedDataLoader(BaseDataLoader):
return self.paired_data
def __len__(self):
- return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.ntrain)
+ return min(max(len(self.dataset_A), len(self.dataset_B)), self.opt.max_dataset_size)
diff --git a/models/models.py b/models/models.py
index 7e790d0..8fea4f4 100644
--- a/models/models.py
+++ b/models/models.py
@@ -4,12 +4,17 @@ def create_model(opt):
print(opt.model)
if opt.model == 'cycle_gan':
from .cycle_gan_model import CycleGANModel
- assert(opt.align_data == False)
+ #assert(opt.align_data == False)
model = CycleGANModel()
- if opt.model == 'pix2pix':
+ elif opt.model == 'pix2pix':
from .pix2pix_model import Pix2PixModel
assert(opt.align_data == True)
model = Pix2PixModel()
+ elif opt.model == 'one_direction_test':
+ from .one_direction_test_model import OneDirectionTestModel
+ model = OneDirectionTestModel()
+ else:
+ raise ValueError("Model [%s] not recognized." % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
diff --git a/models/one_direction_test_model.py b/models/one_direction_test_model.py
new file mode 100644
index 0000000..37e1bbb
--- /dev/null
+++ b/models/one_direction_test_model.py
@@ -0,0 +1,51 @@
+from torch.autograd import Variable
+from collections import OrderedDict
+import util.util as util
+from .base_model import BaseModel
+from . import networks
+
+
+class OneDirectionTestModel(BaseModel):
+ def name(self):
+ return 'OneDirectionTestModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+
+ nb = opt.batchSize
+ size = opt.fineSize
+ self.input_A = self.Tensor(nb, opt.input_nc, size, size)
+
+ assert(not self.isTrain)
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
+ opt.ngf, opt.which_model_netG,
+ opt.norm, opt.use_dropout,
+ self.gpu_ids)
+ which_epoch = opt.which_epoch
+ AtoB = self.opt.which_direction == 'AtoB'
+ which_network = 'G_A' if AtoB else 'G_B'
+ self.load_network(self.netG_A, which_network, which_epoch)
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG_A)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction == 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def test(self):
+ self.real_A = Variable(self.input_A)
+ self.fake_B = self.netG_A.forward(self.real_A)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
+
diff --git a/options/base_options.py b/options/base_options.py
index f03a687..bce0b9c 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -35,6 +35,8 @@ class BaseOptions():
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
+ self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+
self.initialized = True
def parse(self):
diff --git a/options/train_options.py b/options/train_options.py
index 981126f..4b4eac3 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -13,7 +13,6 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--niter', type=int, default=200, help='# of iter at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
- self.parser.add_argument('--ntrain', type=int, default=float("inf"), help='# of examples per epoch.')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')