-rw-r--r--  models/cycle_gan_model.py  6
-rw-r--r--  models/test_model.py       2
-rw-r--r--  options/train_options.py   6
-rw-r--r--  train.py                   6
-rw-r--r--  util/visualizer.py         4
5 files changed, 14 insertions, 10 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index 730d077..85432bb 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -121,7 +121,7 @@ class CycleGANModel(BaseModel):
         self.loss_D_B = loss_D_B.data[0]
 
     def backward_G(self):
-        lambda_idt = self.opt.identity
+        lambda_idt = self.opt.lambda_identity
         lambda_A = self.opt.lambda_A
         lambda_B = self.opt.lambda_B
         # Identity loss
@@ -193,7 +193,7 @@ class CycleGANModel(BaseModel):
     def get_current_errors(self):
         ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
                                   ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
-        if self.opt.identity > 0.0:
+        if self.opt.lambda_identity > 0.0:
             ret_errors['idt_A'] = self.loss_idt_A
             ret_errors['idt_B'] = self.loss_idt_B
         return ret_errors
@@ -207,7 +207,7 @@ class CycleGANModel(BaseModel):
         rec_B = util.tensor2im(self.rec_B)
         ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
                                    ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
-        if self.opt.isTrain and self.opt.identity > 0.0:
+        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
             ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
             ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
         return ret_visuals
diff --git a/models/test_model.py b/models/test_model.py
index 35b7402..f593c46 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -28,7 +28,7 @@ class TestModel(BaseModel):
         # we need to use single_dataset mode
         input_A = input['A']
         if len(self.gpu_ids) > 0:
-            input_A.cuda(self.gpu_ids[0], async=True)
+            input_A = input_A.cuda(self.gpu_ids[0], async=True)
         self.input_A = input_A
         self.image_paths = input['A_paths']
 
diff --git a/options/train_options.py b/options/train_options.py
index f4627ce..3d05a2b 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -21,12 +21,12 @@ class TrainOptions(BaseOptions):
         self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
         self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
         self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
+        self.parser.add_argument('--lambda_identity', type=float, default=0.5,
+                                 help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.'
+                                      'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
         self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
         self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
         self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
-        self.parser.add_argument('--identity', type=float, default=0.5,
-                                 help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss.'
-                                      'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
 
         self.isTrain = True
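Two notes on the hunks above. First, the rename from --identity to --lambda_identity is more than cosmetic: backward_G scales the identity term by this value relative to the cycle-consistency weights. The following self-contained sketch illustrates that weighting as described in the --lambda_identity help text; the generators, tensors, and criterionIdt here are placeholders (an L1 identity criterion is assumed, as in the CycleGAN paper), not a copy of the actual backward_G:

    import torch
    import torch.nn as nn

    # Placeholder generators and batches; in cycle_gan_model.py these are
    # the learned generators and the current training images.
    G_A = nn.Identity()                  # stand-in for the A -> B generator
    G_B = nn.Identity()                  # stand-in for the B -> A generator
    real_A = torch.randn(1, 3, 256, 256)
    real_B = torch.randn(1, 3, 256, 256)

    criterionIdt = nn.L1Loss()           # assumed identity criterion
    lambda_A, lambda_B = 10.0, 10.0      # cycle-loss weights (defaults above)
    lambda_idt = 0.5                     # opt.lambda_identity

    # G_A should behave as an identity map on images already in domain B
    # (and G_B on domain A); lambda_idt scales this term relative to the
    # cycle losses, e.g. lambda_idt = 0.1 makes it 10x smaller.
    loss_idt_A = criterionIdt(G_A(real_B), real_B) * lambda_B * lambda_idt
    loss_idt_B = criterionIdt(G_B(real_A), real_A) * lambda_A * lambda_idt

Second, the test_model.py change is a genuine bug fix: Tensor.cuda() returns a copy on the GPU rather than moving the tensor in place, so the pre-fix call discarded its result and self.input_A stayed on the CPU. (Later PyTorch releases renamed the async= keyword to non_blocking=, because async became a reserved word in Python 3.7.) A minimal illustration:

    import torch

    x = torch.randn(2, 3)                # CPU tensor
    if torch.cuda.is_available():
        x = x.cuda(0)                    # rebinding is required; a bare
                                         # x.cuda(0) would be a silent no-op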
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -16,10 +16,13 @@ total_steps = 0
 
 for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
     epoch_start_time = time.time()
+    iter_data_time = time.time()
     epoch_iter = 0
 
     for i, data in enumerate(dataset):
         iter_start_time = time.time()
+        if total_steps % opt.print_freq == 0:
+            t_data = iter_start_time - iter_data_time
         visualizer.reset()
         total_steps += opt.batchSize
         epoch_iter += opt.batchSize
@@ -33,7 +36,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
         if total_steps % opt.print_freq == 0:
             errors = model.get_current_errors()
             t = (time.time() - iter_start_time) / opt.batchSize
-            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+            visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
             if opt.display_id > 0:
                 visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
 
@@ -42,6 +45,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
                   (epoch, total_steps))
             model.save('latest')
 
+        iter_data_time = time.time()
     if epoch % opt.save_epoch_freq == 0:
         print('saving the model at the end of epoch %d, iters %d' %
               (epoch, total_steps))
diff --git a/util/visualizer.py b/util/visualizer.py
index b22f235..fd8140d 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -114,8 +114,8 @@ class Visualizer():
                          win=self.display_id)
 
     # errors: same format as |errors| of plotCurrentErrors
-    def print_current_errors(self, epoch, i, errors, t):
-        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
+    def print_current_errors(self, epoch, i, errors, t, t_data):
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
         for k, v in errors.items():
             message += '%s: %.3f ' % (k, v)
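The train.py and util/visualizer.py hunks add a lightweight measurement of data-loading overhead: iter_data_time is reset at the end of each iteration, so its gap to the next iter_start_time approximates the time spent waiting on the data loader, and print_current_errors now reports it alongside the per-item compute time t. A self-contained sketch of the pattern, with a dummy list standing in for the real DataLoader (names mirror train.py):

    import time

    dataset = [object()] * 3                  # dummy stand-in for the DataLoader
    print_freq, batchSize = 1, 1
    total_steps = 0

    iter_data_time = time.time()
    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        if total_steps % print_freq == 0:
            # time spent fetching this batch, i.e. loader/IO wait
            t_data = iter_start_time - iter_data_time
            print('iter %d, data: %.3f s' % (i, t_data))
        total_steps += batchSize
        time.sleep(0.01)                      # stand-in for forward/backward
        iter_data_time = time.time()          # restart the clock before the next fetch

If the reported data time stays close to t, training is likely input-bound, and more loader workers or faster storage would help more than a faster GPU.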
