summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/cycle_gan_model.py6
-rw-r--r--models/test_model.py2
-rw-r--r--options/train_options.py6
-rw-r--r--train.py6
-rw-r--r--util/visualizer.py4
5 files changed, 14 insertions, 10 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 730d077..85432bb 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -121,7 +121,7 @@ class CycleGANModel(BaseModel):
self.loss_D_B = loss_D_B.data[0]
def backward_G(self):
- lambda_idt = self.opt.identity
+ lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
# Identity loss
@@ -193,7 +193,7 @@ class CycleGANModel(BaseModel):
def get_current_errors(self):
ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
- if self.opt.identity > 0.0:
+ if self.opt.lambda_identity > 0.0:
ret_errors['idt_A'] = self.loss_idt_A
ret_errors['idt_B'] = self.loss_idt_B
return ret_errors
@@ -207,7 +207,7 @@ class CycleGANModel(BaseModel):
rec_B = util.tensor2im(self.rec_B)
ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
- if self.opt.isTrain and self.opt.identity > 0.0:
+ if self.opt.isTrain and self.opt.lambda_identity > 0.0:
ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
return ret_visuals
diff --git a/models/test_model.py b/models/test_model.py
index 35b7402..f593c46 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -28,7 +28,7 @@ class TestModel(BaseModel):
# we need to use single_dataset mode
input_A = input['A']
if len(self.gpu_ids) > 0:
- input_A.cuda(self.gpu_ids[0], async=True)
+ input_A = input_A.cuda(self.gpu_ids[0], async=True)
self.input_A = input_A
self.image_paths = input['A_paths']
diff --git a/options/train_options.py b/options/train_options.py
index f4627ce..3d05a2b 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -21,12 +21,12 @@ class TrainOptions(BaseOptions):
self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
+ self.parser.add_argument('--lambda_identity', type=float, default=0.5,
+ help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.'
+ 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
- self.parser.add_argument('--identity', type=float, default=0.5,
- help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss.'
- 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
self.isTrain = True
diff --git a/train.py b/train.py
index f6072c7..61b596a 100644
--- a/train.py
+++ b/train.py
@@ -16,10 +16,13 @@ total_steps = 0
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
+ iter_data_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
+ if total_steps % opt.print_freq == 0:
+ t_data = iter_start_time - iter_data_time
visualizer.reset()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
@@ -33,7 +36,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
t = (time.time() - iter_start_time) / opt.batchSize
- visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+ visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
if opt.display_id > 0:
visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
@@ -42,6 +45,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
(epoch, total_steps))
model.save('latest')
+ iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))
diff --git a/util/visualizer.py b/util/visualizer.py
index b22f235..fd8140d 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -114,8 +114,8 @@ class Visualizer():
win=self.display_id)
# errors: same format as |errors| of plotCurrentErrors
- def print_current_errors(self, epoch, i, errors, t):
- message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+ def print_current_errors(self, epoch, i, errors, t, t_data):
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
for k, v in errors.items():
message += '%s: %.3f ' % (k, v)