Diffstat (limited to 'models/cycle_gan_model.py')
| -rw-r--r-- | models/cycle_gan_model.py | 222 |
1 file changed, 222 insertions, 0 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
new file mode 100644
index 0000000..c3b5b72
--- /dev/null
+++ b/models/cycle_gan_model.py
@@ -0,0 +1,222 @@
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from pdb import set_trace as st
+from torch.autograd import Variable
+import itertools
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+import sys
+
+class CycleGANModel(BaseModel):
+    def name(self):
+        return 'CycleGANModel'
+
+    def initialize(self, opt):
+        BaseModel.initialize(self, opt)
+
+        nb = opt.batchSize
+        size = opt.fineSize
+        self.input_A = self.Tensor(nb, opt.input_nc, size, size)
+        self.input_B = self.Tensor(nb, opt.output_nc, size, size)
+
+        # load/define networks
+        # The naming convention is different from the one used in the paper.
+        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
+        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
+                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
+                                        opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+
+        if self.isTrain:
+            use_sigmoid = opt.no_lsgan
+            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
+                                            opt.which_model_netD,
+                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
+            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
+                                            opt.which_model_netD,
+                                            opt.n_layers_D, use_sigmoid, self.gpu_ids)
+        if not self.isTrain or opt.continue_train:
+            which_epoch = opt.which_epoch
+            self.load_network(self.netG_A, 'G_A', which_epoch)
+            self.load_network(self.netG_B, 'G_B', which_epoch)
+            if self.isTrain:
+                self.load_network(self.netD_A, 'D_A', which_epoch)
+                self.load_network(self.netD_B, 'D_B', which_epoch)
+
+        if self.isTrain:
+            self.old_lr = opt.lr
+            self.fake_A_pool = ImagePool(opt.pool_size)
+            self.fake_B_pool = ImagePool(opt.pool_size)
+            # define loss functions
+            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+            self.criterionCycle = torch.nn.L1Loss()
+            self.criterionIdt = torch.nn.L1Loss()
+            # initialize optimizers
+            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
+                                                lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
+                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
+                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
+
+        print('---------- Networks initialized -------------')
+        networks.print_network(self.netG_A)
+        networks.print_network(self.netG_B)
+        if self.isTrain:  # discriminators are only defined in training mode
+            networks.print_network(self.netD_A)
+            networks.print_network(self.netD_B)
+        print('-----------------------------------------------')
+
+    def set_input(self, input):
+        AtoB = self.opt.which_direction == 'AtoB'
+        input_A = input['A' if AtoB else 'B']
+        input_B = input['B' if AtoB else 'A']
+        self.input_A.resize_(input_A.size()).copy_(input_A)
+        self.input_B.resize_(input_B.size()).copy_(input_B)
+        self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+    def forward(self):
+        self.real_A = Variable(self.input_A)
+        self.real_B = Variable(self.input_B)
+
+    def test(self):
+        self.real_A = Variable(self.input_A, volatile=True)
+        self.fake_B = self.netG_A.forward(self.real_A)
+        self.rec_A = self.netG_B.forward(self.fake_B)
+
+        self.real_B = Variable(self.input_B, volatile=True)
+        self.fake_A = self.netG_B.forward(self.real_B)
+        self.rec_B = self.netG_A.forward(self.fake_A)
+
+    # get image paths
+    def get_image_paths(self):
+        return self.image_paths
+
+    def backward_D_basic(self, netD, real, fake):
+        # Real
+        pred_real = netD.forward(real)
+        loss_D_real = self.criterionGAN(pred_real, True)
+        # Fake
+        pred_fake = netD.forward(fake.detach())
+        loss_D_fake = self.criterionGAN(pred_fake, False)
+        # Combined loss
+        loss_D = (loss_D_real + loss_D_fake) * 0.5
+        # backward
+        loss_D.backward()
+        return loss_D
+
+    def backward_D_A(self):
+        fake_B = self.fake_B_pool.query(self.fake_B)
+        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+
+    def backward_D_B(self):
+        fake_A = self.fake_A_pool.query(self.fake_A)
+        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+
+    def backward_G(self):
+        lambda_idt = self.opt.identity
+        lambda_A = self.opt.lambda_A
+        lambda_B = self.opt.lambda_B
+        # Identity loss
+        if lambda_idt > 0:
+            # G_A should be identity if real_B is fed.
+            self.idt_A = self.netG_A.forward(self.real_B)
+            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+            # G_B should be identity if real_A is fed.
+            self.idt_B = self.netG_B.forward(self.real_A)
+            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+        else:
+            self.loss_idt_A = 0
+            self.loss_idt_B = 0
+
+        # GAN loss
+        # D_A(G_A(A))
+        self.fake_B = self.netG_A.forward(self.real_A)
+        pred_fake = self.netD_A.forward(self.fake_B)
+        self.loss_G_A = self.criterionGAN(pred_fake, True)
+        # D_B(G_B(B))
+        self.fake_A = self.netG_B.forward(self.real_B)
+        pred_fake = self.netD_B.forward(self.fake_A)
+        self.loss_G_B = self.criterionGAN(pred_fake, True)
+        # Forward cycle loss
+        self.rec_A = self.netG_B.forward(self.fake_B)
+        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+        # Backward cycle loss
+        self.rec_B = self.netG_A.forward(self.fake_A)
+        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+        # combined loss
+        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        # forward
+        self.forward()
+        # G_A and G_B
+        self.optimizer_G.zero_grad()
+        self.backward_G()
+        self.optimizer_G.step()
+        # D_A
+        self.optimizer_D_A.zero_grad()
+        self.backward_D_A()
+        self.optimizer_D_A.step()
+        # D_B
+        self.optimizer_D_B.zero_grad()
+        self.backward_D_B()
+        self.optimizer_D_B.step()
+
+    def get_current_errors(self):
+        D_A = self.loss_D_A.data[0]
+        G_A = self.loss_G_A.data[0]
+        Cyc_A = self.loss_cycle_A.data[0]
+        D_B = self.loss_D_B.data[0]
+        G_B = self.loss_G_B.data[0]
+        Cyc_B = self.loss_cycle_B.data[0]
+        if self.opt.identity > 0.0:
+            idt_A = self.loss_idt_A.data[0]
+            idt_B = self.loss_idt_B.data[0]
+            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
+                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
+        else:
+            return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
+                                ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+
+    def get_current_visuals(self):
+        real_A = util.tensor2im(self.real_A.data)
+        fake_B = util.tensor2im(self.fake_B.data)
+        rec_A = util.tensor2im(self.rec_A.data)
+        real_B = util.tensor2im(self.real_B.data)
+        fake_A = util.tensor2im(self.fake_A.data)
+        rec_B = util.tensor2im(self.rec_B.data)
+        if self.opt.identity > 0.0:
+            idt_A = util.tensor2im(self.idt_A.data)
+            idt_B = util.tensor2im(self.idt_B.data)
+            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
+                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
+        else:
+            return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+                                ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+
+    def save(self, label):
+        use_gpu = self.gpu_ids is not None
+        self.save_network(self.netG_A, 'G_A', label, use_gpu)
+        self.save_network(self.netD_A, 'D_A', label, use_gpu)
+        self.save_network(self.netG_B, 'G_B', label, use_gpu)
+        self.save_network(self.netD_B, 'D_B', label, use_gpu)
+
+    def update_learning_rate(self):
+        lrd = self.opt.lr / self.opt.niter_decay
+        lr = self.old_lr - lrd
+        for param_group in self.optimizer_D_A.param_groups:
+            param_group['lr'] = lr
+        for param_group in self.optimizer_D_B.param_groups:
+            param_group['lr'] = lr
+        for param_group in self.optimizer_G.param_groups:
+            param_group['lr'] = lr
+
+        print('update learning rate: %f -> %f' % (self.old_lr, lr))
+        self.old_lr = lr
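
For orientation, the sketch below shows how this class would typically be driven from a training script. It is a minimal sketch, not part of this commit: the construction idiom, the data loader, and option fields such as opt.niter are assumptions about the surrounding repository, where the real entry point lives.

    # Hypothetical training driver (sketch only).
    # Assumes `opt` comes from the repository's option parser and `dataset`
    # yields dicts shaped like {'A': tensor, 'B': tensor, 'A_paths': [...], 'B_paths': [...]}.
    model = CycleGANModel()
    model.initialize(opt)
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        for data in dataset:
            model.set_input(data)          # copy the batch into self.input_A / self.input_B
            model.optimize_parameters()    # one joint G_A/G_B step, then one step each for D_A and D_B
        if epoch > opt.niter:
            model.update_learning_rate()   # linear decay over the final niter_decay epochs
    model.save('latest')

Note that optimize_parameters() updates both generators through a single Adam optimizer (their parameters are chained via itertools.chain) before updating each discriminator, and that the discriminators are trained on a mix of current and historical fakes drawn from the ImagePool buffers rather than on the latest generator output alone.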
