diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:38:47 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-18 03:38:47 -0700 |
| commit | c99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch) | |
| tree | ba99dfd56a47036d9c1f18620abf4efc248839ab /models | |
first commit
Diffstat (limited to 'models')
| -rw-r--r-- | models/__init__.py | 0 | ||||
| -rw-r--r-- | models/base_model.py | 56 | ||||
| -rw-r--r-- | models/cycle_gan_model.py | 222 | ||||
| -rw-r--r-- | models/models.py | 15 | ||||
| -rw-r--r-- | models/networks.py | 288 | ||||
| -rw-r--r-- | models/pix2pix_model.py | 147 |
6 files changed, 728 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/models/__init__.py diff --git a/models/base_model.py b/models/base_model.py new file mode 100644 index 0000000..0ea83d8 --- /dev/null +++ b/models/base_model.py @@ -0,0 +1,56 @@ +import os +import torch +from pdb import set_trace as st + +class BaseModel(): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, use_gpu): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if use_gpu and torch.cuda.is_available(): + network.cuda() + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py new file mode 100644 index 0000000..c3b5b72 --- /dev/null +++ b/models/cycle_gan_model.py @@ -0,0 +1,222 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from pdb import set_trace as st +from torch.autograd import Variable +import itertools +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks +import sys + +class CycleGANModel(BaseModel): + def name(self): + return 'CycleGANModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + + nb = opt.batchSize + size = opt.fineSize + self.input_A = self.Tensor(nb, opt.input_nc, size, size) + self.input_B = self.Tensor(nb, opt.output_nc, size, size) + + # load/define networks + # The naming conversion is different from those used in the paper + # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) + + self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, + opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids) + self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, + opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids) + + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD_A = networks.define_D(opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, use_sigmoid, self.gpu_ids) + self.netD_B = networks.define_D(opt.input_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, use_sigmoid, self.gpu_ids) + if not self.isTrain or opt.continue_train: + which_epoch = opt.which_epoch + self.load_network(self.netG_A, 'G_A', which_epoch) + self.load_network(self.netG_B, 'G_B', which_epoch) + if self.isTrain: + self.load_network(self.netD_A, 'D_A', which_epoch) + self.load_network(self.netD_B, 'D_B', which_epoch) + + if self.isTrain: + self.old_lr = opt.lr + self.fake_A_pool = ImagePool(opt.pool_size) + self.fake_B_pool = ImagePool(opt.pool_size) + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionCycle = torch.nn.L1Loss() + self.criterionIdt = torch.nn.L1Loss() + # initialize optimizers + self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG_A) + networks.print_network(self.netG_B) + networks.print_network(self.netD_A) + networks.print_network(self.netD_B) + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction is 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.real_B = Variable(self.input_B) + + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.fake_B = self.netG_A.forward(self.real_A) + self.rec_A = self.netG_B.forward(self.fake_B) + + self.real_B = Variable(self.input_B, volatile=True) + self.fake_A = self.netG_B.forward(self.real_B) + self.rec_B = self.netG_A.forward(self.fake_A) + + #get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D_basic(self, netD, real, fake): + # Real + pred_real = netD.forward(real) + loss_D_real = self.criterionGAN(pred_real, True) + # Fake + pred_fake = netD.forward(fake.detach()) + loss_D_fake = self.criterionGAN(pred_fake, False) + # Combined loss + loss_D = (loss_D_real + loss_D_fake) * 0.5 + # backward + loss_D.backward() + return loss_D + + def backward_D_A(self): + fake_B = self.fake_B_pool.query(self.fake_B) + self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) + + def backward_D_B(self): + fake_A = self.fake_A_pool.query(self.fake_A) + self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) + + def backward_G(self): + lambda_idt = self.opt.identity + lambda_A = self.opt.lambda_A + lambda_B = self.opt.lambda_B + # Identity loss + if lambda_idt > 0: + # G_A should be identity if real_B is fed. + self.idt_A = self.netG_A.forward(self.real_B) + self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt + # G_B should be identity if real_A is fed. + self.idt_B = self.netG_B.forward(self.real_A) + self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt + else: + self.loss_idt_A = 0 + self.loss_idt_B = 0 + + # GAN loss + # D_A(G_A(A)) + self.fake_B = self.netG_A.forward(self.real_A) + pred_fake = self.netD_A.forward(self.fake_B) + self.loss_G_A = self.criterionGAN(pred_fake, True) + # D_B(G_B(B)) + self.fake_A = self.netG_B.forward(self.real_B) + pred_fake = self.netD_B.forward(self.fake_A) + self.loss_G_B = self.criterionGAN(pred_fake, True) + # Forward cycle loss + self.rec_A = self.netG_B.forward(self.fake_B) + self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A + # Backward cycle loss + self.rec_B = self.netG_A.forward(self.fake_A) + self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B + # combined loss + self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B + self.loss_G.backward() + + def optimize_parameters(self): + # forward + self.forward() + # G_A and G_B + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + # D_A + self.optimizer_D_A.zero_grad() + self.backward_D_A() + self.optimizer_D_A.step() + # D_B + self.optimizer_D_B.zero_grad() + self.backward_D_B() + self.optimizer_D_B.step() + + + def get_current_errors(self): + D_A = self.loss_D_A.data[0] + G_A = self.loss_G_A.data[0] + Cyc_A = self.loss_cycle_A.data[0] + D_B = self.loss_D_B.data[0] + G_B = self.loss_G_B.data[0] + Cyc_B = self.loss_cycle_B.data[0] + if self.opt.identity > 0.0: + idt_A = self.loss_idt_A.data[0] + idt_B = self.loss_idt_B.data[0] + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)]) + else: + return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), + ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)]) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + rec_A = util.tensor2im(self.rec_A.data) + real_B = util.tensor2im(self.real_B.data) + fake_A = util.tensor2im(self.fake_A.data) + rec_B = util.tensor2im(self.rec_B.data) + if self.opt.identity > 0.0: + idt_A = util.tensor2im(self.idt_A.data) + idt_B = util.tensor2im(self.idt_B.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)]) + else: + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), + ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)]) + + def save(self, label): + use_gpu = self.gpu_ids is not None + self.save_network(self.netG_A, 'G_A', label, use_gpu) + self.save_network(self.netD_A, 'D_A', label, use_gpu) + self.save_network(self.netG_B, 'G_B', label, use_gpu) + self.save_network(self.netD_B, 'D_B', label, use_gpu) + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D_A.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_D_B.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr diff --git a/models/models.py b/models/models.py new file mode 100644 index 0000000..7e790d0 --- /dev/null +++ b/models/models.py @@ -0,0 +1,15 @@ + +def create_model(opt): + model = None + print(opt.model) + if opt.model == 'cycle_gan': + from .cycle_gan_model import CycleGANModel + assert(opt.align_data == False) + model = CycleGANModel() + if opt.model == 'pix2pix': + from .pix2pix_model import Pix2PixModel + assert(opt.align_data == True) + model = Pix2PixModel() + model.initialize(opt) + print("model [%s] was created" % (model.name())) + return model diff --git a/models/networks.py b/models/networks.py new file mode 100644 index 0000000..d41bd0e --- /dev/null +++ b/models/networks.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +from pdb import set_trace as st + +############################################################################### +# Functions +############################################################################### + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]): + netG = None + use_gpu = len(gpu_ids) > 0 + if norm == 'batch': + norm_layer = nn.BatchNorm2d + elif norm == 'instance': + norm_layer = InstanceNormalization + else: + print('normalization layer [%s] is not found' % norm) + + assert(torch.cuda.is_available() == use_gpu) + if which_model_netG == 'resnet_9blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids) + elif which_model_netG == 'resnet_6blocks': + netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids) + elif which_model_netG == 'unet': + netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids) + else: + print('Generator model name [%s] is not recognized' % which_model_netG) + if use_gpu: + netG.cuda() + netG.apply(weights_init) + return netG + + +def define_D(input_nc, ndf, which_model_netD, + n_layers_D=3, use_sigmoid=False, gpu_ids=[]): + netD = None + use_gpu = len(gpu_ids) > 0 + assert(torch.cuda.is_available() == use_gpu) + if which_model_netD == 'basic': + netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids) + elif which_model_netD == 'n_layers': + netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids) + else: + print('Discriminator model name [%s] is not recognized' % + which_model_netD) + if use_gpu: + netD.cuda() + netD.apply(weights_init) + return netD + + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + +############################################################################## +# Classes +############################################################################## + + +# Defines the GAN loss used in LSGAN. +# It is basically same as MSELoss, but it abstracts away the need to create +# the target label tensor that has the same size as the input +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + target_tensor = self.get_target_tensor(input, target_is_real) + return self.loss(input, target_tensor) + + +# Defines the generator that consists of Resnet blocks between a few +# downsampling/upsampling operations. +# Code and idea originally from Justin Johnson's architecture. +# https://github.com/jcjohnson/fast-neural-style/ +class ResnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]): + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.ngf = ngf + self.gpu_ids = gpu_ids + + model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): + mult = 2**i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, + stride=2, padding=1), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2**n_downsampling + for i in range(n_blocks): + model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)] + + for i in range(n_downsampling): + mult = 2**(n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + + + +# Define a resnet block +class Resnet_block(nn.Module): + def __init__(self, dim, padding_type, norm_layer): + super(Resnet_block, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer) + + def build_conv_block(self, dim, padding_type, norm_layer): + conv_block = [] + p = 0 + # TODO: support padding types + assert(padding_type == 'zero') + p = 1 + + # TODO: InstanceNorm + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim), + nn.ReLU(True)] + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p), + norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + out = x + self.conv_block(x) + return out + + +# Defines the Unet geneator. +class UnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]): + super(UnetGenerator, self).__init__() + self.input_nc = input_nc + self.output_nc = output_nc + self.ngf = ngf + self.gpu_ids = gpu_ids + + def forward(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]): + super(NLayerDiscriminator, self).__init__() + self.gpu_ids = gpu_ids + + kw = 4 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=2), + nn.LeakyReLU(0.2, True) + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=2), + # TODO: use InstanceNorm + nn.BatchNorm2d(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=1, stride=2, padding=2), + # TODO: useInstanceNorm + nn.BatchNorm2d(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)] + sequence += [nn.Sigmoid()] + + self.model = nn.Sequential(*sequence) + + def forward(self, input): + if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids: + return nn.parallel.data_parallel(self.model, input, self.gpu_ids) + else: + return self.model(input) + +# Instance Normalization layer from +# https://github.com/darkstar112358/fast-neural-style + +class InstanceNormalization(torch.nn.Module): + """InstanceNormalization + Improves convergence of neural-style. + ref: https://arxiv.org/pdf/1607.08022.pdf + """ + + def __init__(self, dim, eps=1e-5): + super(InstanceNormalization, self).__init__() + self.weight = nn.Parameter(torch.FloatTensor(dim)) + self.bias = nn.Parameter(torch.FloatTensor(dim)) + self.eps = eps + self._reset_parameters() + + def _reset_parameters(self): + self.weight.data.uniform_() + self.bias.data.zero_() + + def forward(self, x): + n = x.size(2) * x.size(3) + t = x.view(x.size(0), x.size(1), n) + mean = torch.mean(t, 2).unsqueeze(2).expand_as(x) + # Calculate the biased var. torch.var returns unbiased var + var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n)) + scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0) + scale_broadcast = scale_broadcast.expand_as(x) + shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0) + shift_broadcast = shift_broadcast.expand_as(x) + out = (x - mean) / torch.sqrt(var + self.eps) + out = out * scale_broadcast + shift_broadcast + return out diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py new file mode 100644 index 0000000..1d89b29 --- /dev/null +++ b/models/pix2pix_model.py @@ -0,0 +1,147 @@ +import numpy as np +import torch +import os +from collections import OrderedDict +from pdb import set_trace as st +from torch.autograd import Variable +import util.util as util +from util.image_pool import ImagePool +from .base_model import BaseModel +from . import networks + + +class Pix2PixModel(BaseModel): + def name(self): + return 'Pix2PixModel' + + def initialize(self, opt): + BaseModel.initialize(self, opt) + self.isTrain = opt.isTrain + # define tensors + self.input_A = self.Tensor(opt.batchSize, opt.input_nc, + opt.fineSize, opt.fineSize) + self.input_B = self.Tensor(opt.batchSize, opt.output_nc, + opt.fineSize, opt.fineSize) + + # load/define networks + self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, + opt.which_model_netG, opt.norm, self.gpu_ids) + if self.isTrain: + use_sigmoid = opt.no_lsgan + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, + opt.which_model_netD, + opt.n_layers_D, use_sigmoid, self.gpu_ids) + if not self.isTrain or opt.continue_train: + self.load_network(self.netG, 'G', opt.which_epoch) + if self.isTrain: + self.load_network(self.netD, 'D', opt.which_epoch) + + if self.isTrain: + self.fake_AB_pool = ImagePool(opt.pool_size) + self.old_lr = opt.lr + # define loss functions + self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) + self.criterionL1 = torch.nn.L1Loss() + + # initialize optimizers + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), + lr=opt.lr, betas=(opt.beta1, 0.999)) + + print('---------- Networks initialized -------------') + networks.print_network(self.netG) + networks.print_network(self.netD) + print('-----------------------------------------------') + + def set_input(self, input): + AtoB = self.opt.which_direction is 'AtoB' + input_A = input['A' if AtoB else 'B'] + input_B = input['B' if AtoB else 'A'] + self.input_A.resize_(input_A.size()).copy_(input_A) + self.input_B.resize_(input_B.size()).copy_(input_B) + self.image_paths = input['A_paths' if AtoB else 'B_paths'] + + def forward(self): + self.real_A = Variable(self.input_A) + self.fake_B = self.netG.forward(self.real_A) + self.real_B = Variable(self.input_B) + + # no backprop gradients + def test(self): + self.real_A = Variable(self.input_A, volatile=True) + self.fake_B = self.netG.forward(self.real_A) + self.real_B = Variable(self.input_B, volatile=True) + + #get image paths + def get_image_paths(self): + return self.image_paths + + def backward_D(self): + # Fake + # stop backprop to the generator by detaching fake_B + fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) + self.pred_fake = self.netD.forward(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(self.pred_fake, False) + + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1)#.detach() + self.pred_real = self.netD.forward(real_AB) + self.loss_D_real = self.criterionGAN(self.pred_real, True) + + # Combined loss + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + + self.loss_D.backward() + + def backward_G(self): + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD.forward(fake_AB) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A + + self.loss_G = self.loss_G_GAN + self.loss_G_L1 + + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() + + self.optimizer_D.zero_grad() + self.backward_D() + self.optimizer_D.step() + + self.optimizer_G.zero_grad() + self.backward_G() + self.optimizer_G.step() + + def get_current_errors(self): + return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]), + ('G_L1', self.loss_G_L1.data[0]), + ('D_real', self.loss_D_real.data[0]), + ('D_fake', self.loss_D_fake.data[0]) + ]) + + def get_current_visuals(self): + real_A = util.tensor2im(self.real_A.data) + fake_B = util.tensor2im(self.fake_B.data) + real_B = util.tensor2im(self.real_B.data) + return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)]) + + def save(self, label): + use_gpu = self.gpu_ids is not None + self.save_network(self.netG, 'G', label, use_gpu) + self.save_network(self.netD, 'D', label, use_gpu) + + def update_learning_rate(self): + lrd = self.opt.lr / self.opt.niter_decay + lr = self.old_lr - lrd + for param_group in self.optimizer_D.param_groups: + param_group['lr'] = lr + for param_group in self.optimizer_G.param_groups: + param_group['lr'] = lr + print('update learning rate: %f -> %f' % (self.old_lr, lr)) + self.old_lr = lr |
