diff options
Diffstat (limited to 'models/pix2pix_model.py')
| -rw-r--r-- | models/pix2pix_model.py | 147 |
1 files changed, 147 insertions, 0 deletions
# models/pix2pix_model.py (new file, blob 1d89b29)
import numpy as np
import torch
import os
from collections import OrderedDict
from pdb import set_trace as st
from torch.autograd import Variable
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class Pix2PixModel(BaseModel):
    """Paired image-to-image model: generator G plus a conditional discriminator D.

    G maps ``real_A`` to ``fake_B``; D scores the channel-wise concatenation
    of the input with either the real or the generated output. G is trained
    with a GAN loss plus an L1 loss weighted by ``opt.lambda_A``.
    """

    def name(self):
        """Return the model's display name."""
        return 'Pix2PixModel'

    def initialize(self, opt):
        """Build buffers, networks and (when training) losses and optimizers.

        Reads from ``opt``: ``isTrain``, ``batchSize``, ``input_nc``,
        ``output_nc``, ``fineSize``, ``ngf``, ``ndf``, ``which_model_netG``,
        ``which_model_netD``, ``n_layers_D``, ``norm``, ``no_lsgan``,
        ``continue_train``, ``which_epoch``, ``pool_size``, ``lr``, ``beta1``.
        """
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        # Pre-allocated input buffers; set_input() resizes and fills them.
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
                                   opt.fineSize, opt.fineSize)
        self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
                                   opt.fineSize, opt.fineSize)

        # load/define networks
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
                                      opt.which_model_netG, opt.norm,
                                      self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            # D sees input and output stacked along the channel axis.
            self.netD = networks.define_D(opt.input_nc + opt.output_nc,
                                          opt.ndf, opt.which_model_netD,
                                          opt.n_layers_D, use_sigmoid,
                                          self.gpu_ids)
        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG, 'G', opt.which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionL1 = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr,
                                                betas=(opt.beta1, 0.999))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        # netD only exists in training mode; printing it unconditionally
        # raises AttributeError when the model is built for testing.
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one batch into the input buffers, honouring ``which_direction``.

        ``input`` is indexed with the keys 'A', 'B' and 'A_paths' or
        'B_paths'. When the direction is not 'AtoB' the two domains are
        swapped so that ``input_A`` always holds the generator's input.
        """
        # Compare by value: ``is`` tests object identity and only matches a
        # string literal by accident of interning.
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_B.resize_(input_B.size()).copy_(input_B)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run G on the current input and wrap the target for the losses."""
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # no backprop gradients
    def test(self):
        """Run G on the current input with ``volatile=True`` inputs."""
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B, volatile=True)

    # get image paths
    def get_image_paths(self):
        """Return the paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D(self):
        """Compute the discriminator loss and backpropagate it."""
        # Fake
        # stop backprop to the generator by detaching fake_B
        # NOTE(review): presumably the pool may hand back an earlier fake
        # pair instead of the current one -- confirm against ImagePool.query.
        fake_AB = self.fake_AB_pool.query(
            torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)

        # Real
        real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss: average of the fake and real terms.
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        """Compute the generator loss (GAN + weighted L1) and backpropagate it."""
        # First, G(A) should fake the discriminator
        fake_AB = torch.cat((self.real_A, self.fake_B), 1)
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = (self.criterionL1(self.fake_B, self.real_B)
                          * self.opt.lambda_A)

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()

    def optimize_parameters(self):
        """Run one training step: forward pass, update D, then update G."""
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def get_current_errors(self):
        """Return the latest loss values as an ordered name -> value mapping."""
        return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
                            ('G_L1', self.loss_G_L1.data[0]),
                            ('D_real', self.loss_D_real.data[0]),
                            ('D_fake', self.loss_D_fake.data[0])
                            ])

    def get_current_visuals(self):
        """Return input, generated and target images converted by ``util.tensor2im``."""
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        real_B = util.tensor2im(self.real_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B),
                            ('real_B', real_B)])

    def save(self, label):
        """Save both networks under ``label`` via ``save_network``."""
        # NOTE(review): an empty gpu_ids list also counts as "use GPU" here;
        # confirm that matches what save_network expects.
        use_gpu = self.gpu_ids is not None
        self.save_network(self.netG, 'G', label, use_gpu)
        self.save_network(self.netD, 'D', label, use_gpu)

    def update_learning_rate(self):
        """Lower the learning rate of both optimizers by ``lr / niter_decay``."""
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr
