diff options
| -rw-r--r-- | README.md | 48 | ||||
| -rw-r--r-- | data/custom_dataset_data_loader.py | 8 | ||||
| -rw-r--r-- | data/unaligned_dataset.py | 5 | ||||
| -rw-r--r-- | models/base_model.py | 3 | ||||
| -rw-r--r-- | models/cycle_gan_model.py | 143 | ||||
| -rw-r--r-- | models/networks.py | 23 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 16 | ||||
| -rw-r--r-- | models/test_model.py | 2 | ||||
| -rw-r--r-- | options/base_options.py | 6 | ||||
| -rw-r--r-- | options/train_options.py | 4 | ||||
| -rw-r--r-- | train.py | 4 | ||||
| -rw-r--r-- | util/image_pool.py | 6 | ||||
| -rw-r--r-- | util/visualizer.py | 38 |
13 files changed, 182 insertions, 124 deletions
@@ -34,7 +34,31 @@ Image-to-Image Translation with Conditional Adversarial Networks [Phillip Isola](https://people.eecs.berkeley.edu/~isola), [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz), [Tinghui Zhou](https://people.eecs.berkeley.edu/~tinghuiz), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros) In CVPR 2017. +## Other implementations: +### CycleGAN +<p><a href="https://github.com/leehomyc/cyclegan-1"> [Tensorflow]</a> (by Harry Yang), +<a href="https://github.com/architrathore/CycleGAN/">[Tensorflow]</a> (by Archit Rathore), +<a href="https://github.com/vanhuyz/CycleGAN-TensorFlow">[Tensorflow]</a> (by Van Huy), +<a href="https://github.com/XHUJOY/CycleGAN-tensorflow">[Tensorflow]</a> (by Xiaowei Hu), +<a href="https://github.com/LynnHo/CycleGAN-Tensorflow-Simple"> [Tensorflow-simple]</a> (by Zhenliang He), +<a href="https://github.com/luoxier/CycleGAN_Tensorlayer"> [TensorLayer]</a> (by luoxier), +<a href="https://github.com/Aixile/chainer-cyclegan">[Chainer]</a> (by Yanghua Jin), +<a href="https://github.com/yunjey/mnist-svhn-transfer">[Minimal PyTorch]</a> (by yunjey), +<a href="https://github.com/Ldpe2G/DeepLearningForFun/tree/master/Mxnet-Scala/CycleGAN">[Mxnet]</a> (by Ldpe2G), +<a href="https://github.com/tjwei/GANotebooks">[lasagne/keras]</a> (by tjwei)</p> +</ul> +### pix2pix +<p><a href="https://github.com/affinelayer/pix2pix-tensorflow"> [Tensorflow]</a> (by Christopher Hesse), +<a href="https://github.com/Eyyub/tensorflow-pix2pix">[Tensorflow]</a> (by Eyyüb Sariu), +<a href="https://github.com/datitran/face2face-demo"> [Tensorflow (face2face)]</a> (by Dat Tran), +<a href="https://github.com/awjuliani/Pix2Pix-Film"> [Tensorflow (film)]</a> (by Arthur Juliani), +<a href="https://github.com/kaonashi-tyc/zi2zi">[Tensorflow (zi2zi)]</a> (by Yuchen Tian), +<a href="https://github.com/pfnet-research/chainer-pix2pix">[Chainer]</a> (by mattya), +<a href="https://github.com/tjwei/GANotebooks">[tf/torch/keras/lasagne]</a> (by tjwei), +<a href="https://github.com/taey16/pix2pixBEGAN.pytorch">[Pytorch]</a> (by taey16) +</p> +</ul> ## Prerequisites - Linux or macOS @@ -100,17 +124,31 @@ The test results will be saved to a html file here: `./results/facades_pix2pix/l More example scripts can be found at `scripts` directory. ### Apply a pre-trained model (CycleGAN) -If you would like to apply a pre-trained model to a collection of input photos (without image pairs), please use `--dataset_mode single` and `--model test` options. Here is a script to apply a pix2pix model to facade label maps (stored in the directory `facades/testB`). + +If you would like to apply a pre-trained model to a collection of input photos (without image pairs), please use `--dataset_mode single` and `--model test` options. Here is a script to apply a model to Facade label maps (stored in the directory `facades/testB`). ``` bash #!./scripts/test_single.sh -python test.py --dataroot ./datasets/facades/testB/ --name facades_pix2pix --model test --which_model_netG unet_256 --which_direction BtoA --dataset_mode single +python test.py --dataroot ./datasets/facades/testA/ --name {my_trained_model_name} --model test --dataset_mode single +``` +You might want to specify `--which_model_netG` to match the generator architecture of the trained model. + +You can download a few pretrained models from the authors. For example, if you would like to download horse2zebra model, + +```bash +bash pretrained_models/download_cyclegan_model.sh horse2zebra +``` +The pretrained model is saved at `./checkpoints/{name}_pretrained/latest_net_G.pth`. +Then generate the results using + +```bash +python test.py --dataroot datasets/horse2zebra/testA --checkpoints_dir ./checkpoints/ --name horse2zebra_pretrained --no_dropout --model test --dataset_mode single --loadSize 256 --results_dir {directory_path_to_save_result} ``` -Note: We currently don't have pretrained models using PyTorch. This is in part because the models trained using Torch and PyTorch produce slightly different results, although we were not able to decide which result is better. If you would like to generate the same results that appeared in our paper, we recommend using the pretrained models in the Torch codebase. +Note: We currently don't have all pretrained models using PyTorch. This is in part because the models trained using Torch and PyTorch produce slightly different results, although we were not able to decide which result is better. If you would like to generate the same results that appeared in our paper, we recommend using the pretrained models in the Torch codebase. ### Apply a pre-trained model (pix2pix) -Download the pre-trained models using `./pretrained_models/download_pix2pix_model.sh`. For example, if you would like to download label2photo model on the Facades dataset, +Download the pre-trained models using `./pretrained_models/download_pix2pix_model.sh`. For example, if you would like to download label2photo model on the Facades dataset, ```bash bash pretrained_models/download_pix2pix_model.sh facades_label2photo @@ -120,7 +158,7 @@ Then generate the results using ```bash python test.py --dataroot ./datasets/facades/ --which_direction BtoA --model pix2pix --name facades_label2photo_pretrained --dataset_mode aligned --which_model_netG unet_256 --norm batch ``` -Note that we specified `--which_direction BtoA` to accomodate the fact that the Facades dataset's A to B direction is photos to labels. +Note that we specified `--which_direction BtoA` to accomodate the fact that the Facades dataset's A to B direction is photos to labels. Also, the models that are currently available to download can be found by reading the output of `bash pretrained_models/download_pix2pix_model.sh` diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py index 60180e0..787946f 100644 --- a/data/custom_dataset_data_loader.py +++ b/data/custom_dataset_data_loader.py @@ -35,7 +35,13 @@ class CustomDatasetDataLoader(BaseDataLoader): num_workers=int(opt.nThreads)) def load_data(self): - return self.dataloader + return self def __len__(self): return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + for i, data in enumerate(self.dataloader): + if i >= self.opt.max_dataset_size: + break + yield data diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py index c5e5460..ad0c11b 100644 --- a/data/unaligned_dataset.py +++ b/data/unaligned_dataset.py @@ -25,7 +25,10 @@ class UnalignedDataset(BaseDataset): def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] index_A = index % self.A_size - index_B = random.randint(0, self.B_size - 1) + if self.opt.serial_batches: + index_B = index % self.B_size + else: + index_B = random.randint(0, self.B_size - 1) B_path = self.B_paths[index_B] # print('(A, B) = (%d, %d)' % (index_A, index_B)) A_img = Image.open(A_path).convert('RGB') diff --git a/models/base_model.py b/models/base_model.py index 446a903..9b55afe 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -44,13 +44,14 @@ class BaseModel(): save_path = os.path.join(self.save_dir, save_filename) torch.save(network.cpu().state_dict(), save_path) if len(gpu_ids) and torch.cuda.is_available(): - network.cuda(device_id=gpu_ids[0]) + network.cuda(gpu_ids[0]) # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch_label): save_filename = '%s_net_%s.pth' % (epoch_label, network_label) save_path = os.path.join(self.save_dir, save_filename) network.load_state_dict(torch.load(save_path)) + # update learning rate (called once every epoch) def update_learning_rate(self): for scheduler in self.schedulers: diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index 74771cf..fe06823 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -44,9 +44,9 @@ class CycleGANModel(BaseModel): which_epoch = opt.which_epoch self.load_network(self.netG_A, 'G_A', which_epoch) self.load_network(self.netG_B, 'G_B', which_epoch) - #if self.isTrain: - # self.load_network(self.netD_A, 'D_A', which_epoch) - # self.load_network(self.netD_B, 'D_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + self.load_network(self.netD_B, 'D_B', which_epoch) if self.isTrain: self.old_lr = opt.lr @@ -77,8 +77,6 @@ class CycleGANModel(BaseModel): networks.print_network(self.netD_B) print('-----------------------------------------------') - self.step_count = 0 - def set_input(self, input): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] @@ -86,20 +84,21 @@ class CycleGANModel(BaseModel): self.input_A.resize_(input_A.size()).copy_(input_A) self.input_B.resize_(input_B.size()).copy_(input_B) self.image_paths = input['A_paths' if AtoB else 'B_paths'] - self.image_paths2 = input['B_paths' if AtoB else 'A_paths'] def forward(self): self.real_A = Variable(self.input_A) self.real_B = Variable(self.input_B) def test(self): - self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG_A.forward(self.real_A) - self.rec_A = self.netG_B.forward(self.fake_B) + real_A = Variable(self.input_A, volatile=True) + fake_B = self.netG_A(real_A) + self.rec_A = self.netG_B(fake_B).data + self.fake_B = fake_B.data - self.real_B = Variable(self.input_B, volatile=True) - self.fake_A = self.netG_B.forward(self.real_B) - self.rec_B = self.netG_A.forward(self.fake_A) + real_B = Variable(self.input_B, volatile=True) + fake_A = self.netG_B(real_B) + self.rec_B = self.netG_A(fake_A).data + self.fake_A = fake_A.data # get image paths def get_image_paths(self): @@ -107,10 +106,10 @@ class CycleGANModel(BaseModel): def backward_D_basic(self, netD, real, fake): # Real - pred_real = netD.forward(real) + pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake - pred_fake = netD.forward(fake.detach()) + pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 @@ -120,11 +119,13 @@ class CycleGANModel(BaseModel): def backward_D_A(self): fake_B = self.fake_B_pool.query(self.fake_B) - self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + self.loss_D_A = loss_D_A.data[0] def backward_D_B(self): fake_A = self.fake_A_pool.query(self.fake_A) - self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + self.loss_D_B = loss_D_B.data[0] def backward_G(self): lambda_idt = self.opt.identity @@ -133,51 +134,59 @@ class CycleGANModel(BaseModel): # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed. - self.idt_A = self.netG_A.forward(self.real_B) - self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + idt_A = self.netG_A(self.real_B) + loss_idt_A = self.criterionIdt(idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed. - self.idt_B = self.netG_B.forward(self.real_A) - self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + idt_B = self.netG_B(self.real_A) + loss_idt_B = self.criterionIdt(idt_B, self.real_A) * lambda_A * lambda_idt + + self.idt_A = idt_A.data + self.idt_B = idt_B.data + self.loss_idt_A = loss_idt_A.data[0] + self.loss_idt_B = loss_idt_B.data[0] else: + loss_idt_A = 0 + loss_idt_B = 0 self.loss_idt_A = 0 self.loss_idt_B = 0 - - # GAN loss - # D_A(G_A(A)) - self.fake_B = self.netG_A.forward(self.real_A) - pred_fake = self.netD_A.forward(self.fake_B) - self.loss_G_A = self.criterionGAN(pred_fake, True) - # D_B(G_B(B)) - self.fake_A = self.netG_B.forward(self.real_B) - pred_fake = self.netD_B.forward(self.fake_A) - self.loss_G_B = self.criterionGAN(pred_fake, True) + + # GAN loss D_A(G_A(A)) + fake_B = self.netG_A(self.real_A) + pred_fake = self.netD_A(fake_B) + loss_G_A = self.criterionGAN(pred_fake, True) + + # GAN loss D_B(G_B(B)) + fake_A = self.netG_B(self.real_B) + pred_fake = self.netD_B(fake_A) + loss_G_B = self.criterionGAN(pred_fake, True) # Forward cycle loss - self.rec_A = self.netG_B.forward(self.fake_B) - self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + rec_A = self.netG_B(fake_B) + loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A + # Backward cycle loss - self.rec_B = self.netG_A.forward(self.fake_A) - self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + rec_B = self.netG_A(fake_A) + loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B # combined loss - self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B - self.loss_G.backward() + loss_G = loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B + loss_idt_A + loss_idt_B + loss_G.backward() + + self.fake_B = fake_B.data + self.fake_A = fake_A.data + self.rec_A = rec_A.data + self.rec_B = rec_B.data + + self.loss_G_A = loss_G_A.data[0] + self.loss_G_B = loss_G_B.data[0] + self.loss_cycle_A = loss_cycle_A.data[0] + self.loss_cycle_B = loss_cycle_B.data[0] def optimize_parameters(self): - self.step_count += 1 # forward self.forward() # G_A and G_B self.optimizer_G.zero_grad() self.backward_G() - if (self.loss_G != self.loss_G).sum().data[0] > 0: - exit(1) - #for w in self.netG_A.parameters(): - #print(w.grad.data) - # if (w.grad.data != w.grad.data).sum() > 0: - # print(w.grad.data) - # exit(1) - #print(self.image_paths, self.image_paths2) - #return self.optimizer_G.step() # D_A self.optimizer_D_A.zero_grad() @@ -189,36 +198,26 @@ class CycleGANModel(BaseModel): self.optimizer_D_B.step() def get_current_errors(self): - D_A = self.loss_D_A.data[0] - G_A = self.loss_G_A.data[0] - Cyc_A = self.loss_cycle_A.data[0] - D_B = self.loss_D_B.data[0] - G_B = self.loss_G_B.data[0] - Cyc_B = self.loss_cycle_B.data[0] + ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A), + ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)]) if self.opt.identity > 0.0: - idt_A = self.loss_idt_A.data[0] - idt_B = self.loss_idt_B.data[0] - return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), - ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)]) - else: - return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), - ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)]) + ret_errors['idt_A'] = self.loss_idt_A + ret_errors['idt_B'] = self.loss_idt_B + return ret_errors def get_current_visuals(self): - real_A = util.tensor2im(self.real_A.data) - fake_B = util.tensor2im(self.fake_B.data) - rec_A = util.tensor2im(self.rec_A.data) - real_B = util.tensor2im(self.real_B.data) - fake_A = util.tensor2im(self.fake_A.data) - rec_B = util.tensor2im(self.rec_B.data) + real_A = util.tensor2im(self.input_A) + fake_B = util.tensor2im(self.fake_B) + rec_A = util.tensor2im(self.rec_A) + real_B = util.tensor2im(self.input_B) + fake_A = util.tensor2im(self.fake_A) + rec_B = util.tensor2im(self.rec_B) + ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) if self.opt.isTrain and self.opt.identity > 0.0: - idt_A = util.tensor2im(self.idt_A.data) - idt_B = util.tensor2im(self.idt_B.data) - return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B), - ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)]) - else: - return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), - ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) + ret_visuals['idt_A'] = util.tensor2im(self.idt_A) + ret_visuals['idt_B'] = util.tensor2im(self.idt_B) + return ret_visuals def save(self, label): self.save_network(self.netG_A, 'G_A', label, self.gpu_ids) diff --git a/models/networks.py b/models/networks.py index 965bacb..568f8c9 100644 --- a/models/networks.py +++ b/models/networks.py @@ -10,16 +10,15 @@ import numpy as np ############################################################################### - def weights_init_normal(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1: - init.uniform(m.weight.data, 0.0, 0.02) + init.normal(m.weight.data, 0.0, 0.02) elif classname.find('Linear') != -1: - init.uniform(m.weight.data, 0.0, 0.02) + init.normal(m.weight.data, 0.0, 0.02) elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -27,11 +26,11 @@ def weights_init_xavier(m): classname = m.__class__.__name__ # print(classname) if classname.find('Conv') != -1: - init.xavier_normal(m.weight.data, gain=1) + init.xavier_normal(m.weight.data, gain=0.02) elif classname.find('Linear') != -1: - init.xavier_normal(m.weight.data, gain=1) + init.xavier_normal(m.weight.data, gain=0.02) elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -43,7 +42,7 @@ def weights_init_kaiming(m): elif classname.find('Linear') != -1: init.kaiming_normal(m.weight.data, a=0, mode='fan_in') elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -55,7 +54,7 @@ def weights_init_orthogonal(m): elif classname.find('Linear') != -1: init.orthogonal(m.weight.data, gain=1) elif classname.find('BatchNorm2d') != -1: - init.uniform(m.weight.data, 1.0, 0.02) + init.normal(m.weight.data, 1.0, 0.02) init.constant(m.bias.data, 0.0) @@ -88,7 +87,7 @@ def get_norm_layer(norm_type='instance'): def get_scheduler(optimizer, opt): if opt.lr_policy == 'lambda': def lambda_rule(epoch): - lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1) + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': @@ -119,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo else: raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: - netG.cuda(device_id=gpu_ids[0]) + netG.cuda(gpu_ids[0]) init_weights(netG, init_type=init_type) return netG @@ -142,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) + netD.cuda(gpu_ids[0]) init_weights(netD, init_type=init_type) return netD diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py index 18ba53f..56adfc1 100644 --- a/models/pix2pix_model.py +++ b/models/pix2pix_model.py @@ -70,13 +70,13 @@ class Pix2PixModel(BaseModel): def forward(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B) # no backprop gradients def test(self): self.real_A = Variable(self.input_A, volatile=True) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) self.real_B = Variable(self.input_B, volatile=True) # get image paths @@ -86,14 +86,14 @@ class Pix2PixModel(BaseModel): def backward_D(self): # Fake # stop backprop to the generator by detaching fake_B - fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) - self.pred_fake = self.netD.forward(fake_AB.detach()) - self.loss_D_fake = self.criterionGAN(self.pred_fake, False) + fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1).data) + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) # Real real_AB = torch.cat((self.real_A, self.real_B), 1) - self.pred_real = self.netD.forward(real_AB) - self.loss_D_real = self.criterionGAN(self.pred_real, True) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) # Combined loss self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 @@ -103,7 +103,7 @@ class Pix2PixModel(BaseModel): def backward_G(self): # First, G(A) should fake the discriminator fake_AB = torch.cat((self.real_A, self.fake_B), 1) - pred_fake = self.netD.forward(fake_AB) + pred_fake = self.netD(fake_AB) self.loss_G_GAN = self.criterionGAN(pred_fake, True) # Second, G(A) = B diff --git a/models/test_model.py b/models/test_model.py index 4af1fe1..2ae2812 100644 --- a/models/test_model.py +++ b/models/test_model.py @@ -34,7 +34,7 @@ class TestModel(BaseModel): def test(self): self.real_A = Variable(self.input_A) - self.fake_B = self.netG.forward(self.real_A) + self.fake_B = self.netG(self.real_A) # get image paths def get_image_paths(self): diff --git a/options/base_options.py b/options/base_options.py index e89f144..13466bf 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -3,9 +3,10 @@ import os from util import util import torch + class BaseOptions(): def __init__(self): - self.parser = argparse.ArgumentParser() + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.initialized = False def initialize(self): @@ -33,12 +34,11 @@ class BaseOptions(): self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') - self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') - self.parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') + self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') self.initialized = True diff --git a/options/train_options.py b/options/train_options.py index 32120ec..603d76a 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -5,6 +5,8 @@ class TrainOptions(BaseOptions): def initialize(self): BaseOptions.initialize(self) self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') + self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') + self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') @@ -23,6 +25,6 @@ class TrainOptions(BaseOptions): self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') - self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1') + self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1') self.isTrain = True @@ -20,13 +20,15 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): for i, data in enumerate(dataset): iter_start_time = time.time() + visualizer.reset() total_steps += opt.batchSize epoch_iter += opt.batchSize model.set_input(data) model.optimize_parameters() if total_steps % opt.display_freq == 0: - visualizer.display_current_results(model.get_current_visuals(), epoch) + save_result = total_steps % opt.update_html_freq == 0 + visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) if total_steps % opt.print_freq == 0: errors = model.get_current_errors() diff --git a/util/image_pool.py b/util/image_pool.py index 152ef5b..ada1627 100644 --- a/util/image_pool.py +++ b/util/image_pool.py @@ -2,6 +2,8 @@ import random import numpy as np import torch from torch.autograd import Variable + + class ImagePool(): def __init__(self, pool_size): self.pool_size = pool_size @@ -11,9 +13,9 @@ class ImagePool(): def query(self, images): if self.pool_size == 0: - return images + return Variable(images) return_images = [] - for image in images.data: + for image in images: image = torch.unsqueeze(image, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 diff --git a/util/visualizer.py b/util/visualizer.py index 02a36b7..e6e7cba 100644 --- a/util/visualizer.py +++ b/util/visualizer.py @@ -4,7 +4,8 @@ import ntpath import time from . import util from . import html -from pdb import set_trace as st + + class Visualizer(): def __init__(self, opt): # self.opt = opt @@ -12,10 +13,11 @@ class Visualizer(): self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name + self.opt = opt + self.saved = False if self.display_id > 0: import visdom - self.vis = visdom.Visdom(port = opt.display_port) - self.display_single_pane_ncols = opt.display_single_pane_ncols + self.vis = visdom.Visdom(port=opt.display_port) if self.use_html: self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') @@ -27,16 +29,19 @@ class Visualizer(): now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) + def reset(self): + self.saved = False + # |visuals|: dictionary of images to display or save - def display_current_results(self, visuals, epoch): - if self.display_id > 0: # show images in the browser - if self.display_single_pane_ncols > 0: + def display_current_results(self, visuals, epoch, save_result): + if self.display_id > 0: # show images in the browser + ncols = self.opt.display_single_pane_ncols + if ncols > 0: h, w = next(iter(visuals.values())).shape[:2] table_css = """<style> - table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} - table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} -</style>""" % (w, h) - ncols = self.display_single_pane_ncols + table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center} + table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} + </style>""" % (w, h) title = self.name label_html = '' label_html_row = '' @@ -61,16 +66,17 @@ class Visualizer(): self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '<table>%s</table>' % label_html - self.vis.text(table_css + label_html, win = self.display_id + 2, + self.vis.text(table_css + label_html, win=self.display_id + 2, opts=dict(title=title + ' labels')) else: idx = 1 for label, image_numpy in visuals.items(): - self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label), - win=self.display_id + idx) + self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), + win=self.display_id + idx) idx += 1 - if self.use_html: # save images to a html file + if self.use_html and (save_result or not self.saved): # save images to a html file + self.saved = True for label, image_numpy in visuals.items(): img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) @@ -93,11 +99,11 @@ class Visualizer(): # errors: dictionary of error labels and values def plot_current_errors(self, epoch, counter_ratio, opt, errors): if not hasattr(self, 'plot_data'): - self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())} + self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())} self.plot_data['X'].append(epoch + counter_ratio) self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) self.vis.line( - X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1), + X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), Y=np.array(self.plot_data['Y']), opts={ 'title': self.name + ' loss over time', |
