summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
commitc99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch)
treeba99dfd56a47036d9c1f18620abf4efc248839ab /models
first commit
Diffstat (limited to 'models')
-rw-r--r--models/__init__.py0
-rw-r--r--models/base_model.py56
-rw-r--r--models/cycle_gan_model.py222
-rw-r--r--models/models.py15
-rw-r--r--models/networks.py288
-rw-r--r--models/pix2pix_model.py147
6 files changed, 728 insertions, 0 deletions
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/models/__init__.py
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100644
index 0000000..0ea83d8
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,56 @@
+import os
+import torch
+from pdb import set_trace as st
+
+class BaseModel():
+ def name(self):
+ return 'BaseModel'
+
+ def initialize(self, opt):
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+
+ def set_input(self, input):
+ self.input = input
+
+ def forward(self):
+ pass
+
+ # used in test time, no backprop
+ def test(self):
+ pass
+
+ def get_image_paths(self):
+ pass
+
+ def optimize_parameters(self):
+ pass
+
+ def get_current_visuals(self):
+ return self.input
+
+ def get_current_errors(self):
+ return {}
+
+ def save(self, label):
+ pass
+
+ # helper saving function that can be used by subclasses
+ def save_network(self, network, network_label, epoch_label, use_gpu):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ torch.save(network.cpu().state_dict(), save_path)
+ if use_gpu and torch.cuda.is_available():
+ network.cuda()
+
+ # helper loading function that can be used by subclasses
+ def load_network(self, network, network_label, epoch_label):
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
+ save_path = os.path.join(self.save_dir, save_filename)
+ network.load_state_dict(torch.load(save_path))
+
+ def update_learning_rate():
+ pass
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
new file mode 100644
index 0000000..c3b5b72
--- /dev/null
+++ b/models/cycle_gan_model.py
@@ -0,0 +1,222 @@
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from pdb import set_trace as st
+from torch.autograd import Variable
+import itertools
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+import sys
+
+class CycleGANModel(BaseModel):
+ def name(self):
+ return 'CycleGANModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+
+ nb = opt.batchSize
+ size = opt.fineSize
+ self.input_A = self.Tensor(nb, opt.input_nc, size, size)
+ self.input_B = self.Tensor(nb, opt.output_nc, size, size)
+
+ # load/define networks
+ # The naming conversion is different from those used in the paper
+ # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
+
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
+ opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+ self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
+ opt.ngf, opt.which_model_netG, opt.norm, self.gpu_ids)
+
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
+ opt.which_model_netD,
+ opt.n_layers_D, use_sigmoid, self.gpu_ids)
+ self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
+ opt.which_model_netD,
+ opt.n_layers_D, use_sigmoid, self.gpu_ids)
+ if not self.isTrain or opt.continue_train:
+ which_epoch = opt.which_epoch
+ self.load_network(self.netG_A, 'G_A', which_epoch)
+ self.load_network(self.netG_B, 'G_B', which_epoch)
+ if self.isTrain:
+ self.load_network(self.netD_A, 'D_A', which_epoch)
+ self.load_network(self.netD_B, 'D_B', which_epoch)
+
+ if self.isTrain:
+ self.old_lr = opt.lr
+ self.fake_A_pool = ImagePool(opt.pool_size)
+ self.fake_B_pool = ImagePool(opt.pool_size)
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionCycle = torch.nn.L1Loss()
+ self.criterionIdt = torch.nn.L1Loss()
+ # initialize optimizers
+ self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG_A)
+ networks.print_network(self.netG_B)
+ networks.print_network(self.netD_A)
+ networks.print_network(self.netD_B)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction is 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ input_B = input['B' if AtoB else 'A']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.input_B.resize_(input_B.size()).copy_(input_B)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ self.real_A = Variable(self.input_A)
+ self.real_B = Variable(self.input_B)
+
+ def test(self):
+ self.real_A = Variable(self.input_A, volatile=True)
+ self.fake_B = self.netG_A.forward(self.real_A)
+ self.rec_A = self.netG_B.forward(self.fake_B)
+
+ self.real_B = Variable(self.input_B, volatile=True)
+ self.fake_A = self.netG_B.forward(self.real_B)
+ self.rec_B = self.netG_A.forward(self.fake_A)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def backward_D_basic(self, netD, real, fake):
+ # Real
+ pred_real = netD.forward(real)
+ loss_D_real = self.criterionGAN(pred_real, True)
+ # Fake
+ pred_fake = netD.forward(fake.detach())
+ loss_D_fake = self.criterionGAN(pred_fake, False)
+ # Combined loss
+ loss_D = (loss_D_real + loss_D_fake) * 0.5
+ # backward
+ loss_D.backward()
+ return loss_D
+
+ def backward_D_A(self):
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+
+ def backward_D_B(self):
+ fake_A = self.fake_A_pool.query(self.fake_A)
+ self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+
+ def backward_G(self):
+ lambda_idt = self.opt.identity
+ lambda_A = self.opt.lambda_A
+ lambda_B = self.opt.lambda_B
+ # Identity loss
+ if lambda_idt > 0:
+ # G_A should be identity if real_B is fed.
+ self.idt_A = self.netG_A.forward(self.real_B)
+ self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+ # G_B should be identity if real_A is fed.
+ self.idt_B = self.netG_B.forward(self.real_A)
+ self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+ else:
+ self.loss_idt_A = 0
+ self.loss_idt_B = 0
+
+ # GAN loss
+ # D_A(G_A(A))
+ self.fake_B = self.netG_A.forward(self.real_A)
+ pred_fake = self.netD_A.forward(self.fake_B)
+ self.loss_G_A = self.criterionGAN(pred_fake, True)
+ # D_B(G_B(B))
+ self.fake_A = self.netG_B.forward(self.real_B)
+ pred_fake = self.netD_B.forward(self.fake_A)
+ self.loss_G_B = self.criterionGAN(pred_fake, True)
+ # Forward cycle loss
+ self.rec_A = self.netG_B.forward(self.fake_B)
+ self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+ # Backward cycle loss
+ self.rec_B = self.netG_A.forward(self.fake_A)
+ self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+ # combined loss
+ self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+ self.loss_G.backward()
+
+ def optimize_parameters(self):
+ # forward
+ self.forward()
+ # G_A and G_B
+ self.optimizer_G.zero_grad()
+ self.backward_G()
+ self.optimizer_G.step()
+ # D_A
+ self.optimizer_D_A.zero_grad()
+ self.backward_D_A()
+ self.optimizer_D_A.step()
+ # D_B
+ self.optimizer_D_B.zero_grad()
+ self.backward_D_B()
+ self.optimizer_D_B.step()
+
+
+ def get_current_errors(self):
+ D_A = self.loss_D_A.data[0]
+ G_A = self.loss_G_A.data[0]
+ Cyc_A = self.loss_cycle_A.data[0]
+ D_B = self.loss_D_B.data[0]
+ G_B = self.loss_G_B.data[0]
+ Cyc_B = self.loss_cycle_B.data[0]
+ if self.opt.identity > 0.0:
+ idt_A = self.loss_idt_A.data[0]
+ idt_B = self.loss_idt_B.data[0]
+ return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A), ('idt_A', idt_A),
+ ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B), ('idt_B', idt_B)])
+ else:
+ return OrderedDict([('D_A', D_A), ('G_A', G_A), ('Cyc_A', Cyc_A),
+ ('D_B', D_B), ('G_B', G_B), ('Cyc_B', Cyc_B)])
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ rec_A = util.tensor2im(self.rec_A.data)
+ real_B = util.tensor2im(self.real_B.data)
+ fake_A = util.tensor2im(self.fake_A.data)
+ rec_B = util.tensor2im(self.rec_B.data)
+ if self.opt.identity > 0.0:
+ idt_A = util.tensor2im(self.idt_A.data)
+ idt_B = util.tensor2im(self.idt_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A), ('idt_B', idt_B),
+ ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B), ('idt_A', idt_A)])
+ else:
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
+ ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
+
+ def save(self, label):
+ use_gpu = self.gpu_ids is not None
+ self.save_network(self.netG_A, 'G_A', label, use_gpu)
+ self.save_network(self.netD_A, 'D_A', label, use_gpu)
+ self.save_network(self.netG_B, 'G_B', label, use_gpu)
+ self.save_network(self.netD_B, 'D_B', label, use_gpu)
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D_A.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_D_B.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr
diff --git a/models/models.py b/models/models.py
new file mode 100644
index 0000000..7e790d0
--- /dev/null
+++ b/models/models.py
@@ -0,0 +1,15 @@
+
+def create_model(opt):
+ model = None
+ print(opt.model)
+ if opt.model == 'cycle_gan':
+ from .cycle_gan_model import CycleGANModel
+ assert(opt.align_data == False)
+ model = CycleGANModel()
+ if opt.model == 'pix2pix':
+ from .pix2pix_model import Pix2PixModel
+ assert(opt.align_data == True)
+ model = Pix2PixModel()
+ model.initialize(opt)
+ print("model [%s] was created" % (model.name()))
+ return model
diff --git a/models/networks.py b/models/networks.py
new file mode 100644
index 0000000..d41bd0e
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+from pdb import set_trace as st
+
+###############################################################################
+# Functions
+###############################################################################
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ m.weight.data.normal_(0.0, 0.02)
+ elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+ m.weight.data.normal_(1.0, 0.02)
+ m.bias.data.fill_(0)
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
+ netG = None
+ use_gpu = len(gpu_ids) > 0
+ if norm == 'batch':
+ norm_layer = nn.BatchNorm2d
+ elif norm == 'instance':
+ norm_layer = InstanceNormalization
+ else:
+ print('normalization layer [%s] is not found' % norm)
+
+ assert(torch.cuda.is_available() == use_gpu)
+ if which_model_netG == 'resnet_9blocks':
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
+ elif which_model_netG == 'resnet_6blocks':
+ netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
+ elif which_model_netG == 'unet':
+ netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids)
+ else:
+ print('Generator model name [%s] is not recognized' % which_model_netG)
+ if use_gpu:
+ netG.cuda()
+ netG.apply(weights_init)
+ return netG
+
+
+def define_D(input_nc, ndf, which_model_netD,
+ n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
+ netD = None
+ use_gpu = len(gpu_ids) > 0
+ assert(torch.cuda.is_available() == use_gpu)
+ if which_model_netD == 'basic':
+ netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
+ elif which_model_netD == 'n_layers':
+ netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, use_sigmoid, gpu_ids=gpu_ids)
+ else:
+ print('Discriminator model name [%s] is not recognized' %
+ which_model_netD)
+ if use_gpu:
+ netD.cuda()
+ netD.apply(weights_init)
+ return netD
+
+
+def print_network(net):
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print(net)
+ print('Total number of parameters: %d' % num_params)
+
+
+##############################################################################
+# Classes
+##############################################################################
+
+
+# Defines the GAN loss used in LSGAN.
+# It is basically same as MSELoss, but it abstracts away the need to create
+# the target label tensor that has the same size as the input
+class GANLoss(nn.Module):
+ def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
+ tensor=torch.FloatTensor):
+ super(GANLoss, self).__init__()
+ self.real_label = target_real_label
+ self.fake_label = target_fake_label
+ self.real_label_var = None
+ self.fake_label_var = None
+ self.Tensor = tensor
+ if use_lsgan:
+ self.loss = nn.MSELoss()
+ else:
+ self.loss = nn.BCELoss()
+
+ def get_target_tensor(self, input, target_is_real):
+ target_tensor = None
+ if target_is_real:
+ create_label = ((self.real_label_var is None) or
+ (self.real_label_var.numel() != input.numel()))
+ if create_label:
+ real_tensor = self.Tensor(input.size()).fill_(self.real_label)
+ self.real_label_var = Variable(real_tensor, requires_grad=False)
+ target_tensor = self.real_label_var
+ else:
+ create_label = ((self.fake_label_var is None) or
+ (self.fake_label_var.numel() != input.numel()))
+ if create_label:
+ fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
+ self.fake_label_var = Variable(fake_tensor, requires_grad=False)
+ target_tensor = self.fake_label_var
+ return target_tensor
+
+ def __call__(self, input, target_is_real):
+ target_tensor = self.get_target_tensor(input, target_is_real)
+ return self.loss(input, target_tensor)
+
+
+# Defines the generator that consists of Resnet blocks between a few
+# downsampling/upsampling operations.
+# Code and idea originally from Justin Johnson's architecture.
+# https://github.com/jcjohnson/fast-neural-style/
+class ResnetGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_blocks=6, gpu_ids=[]):
+ assert(n_blocks >= 0)
+ super(ResnetGenerator, self).__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.ngf = ngf
+ self.gpu_ids = gpu_ids
+
+ model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling):
+ mult = 2**i
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
+ stride=2, padding=1),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+
+ mult = 2**n_downsampling
+ for i in range(n_blocks):
+ model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)]
+
+ for i in range(n_downsampling):
+ mult = 2**(n_downsampling - i)
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=3)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+
+
+# Define a resnet block
+class Resnet_block(nn.Module):
+ def __init__(self, dim, padding_type, norm_layer):
+ super(Resnet_block, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)
+
+ def build_conv_block(self, dim, padding_type, norm_layer):
+ conv_block = []
+ p = 0
+ # TODO: support padding types
+ assert(padding_type == 'zero')
+ p = 1
+
+ # TODO: InstanceNorm
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim),
+ nn.ReLU(True)]
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
+ norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ out = x + self.conv_block(x)
+ return out
+
+
+# Defines the Unet geneator.
+class UnetGenerator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+ super(UnetGenerator, self).__init__()
+ self.input_nc = input_nc
+ self.output_nc = output_nc
+ self.ngf = ngf
+ self.gpu_ids = gpu_ids
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, use_sigmoid=False, gpu_ids=[]):
+ super(NLayerDiscriminator, self).__init__()
+ self.gpu_ids = gpu_ids
+
+ kw = 4
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=2),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers):
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
+ kernel_size=kw, stride=2, padding=2),
+ # TODO: use InstanceNorm
+ nn.BatchNorm2d(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
+ kernel_size=1, stride=2, padding=2),
+ # TODO: useInstanceNorm
+ nn.BatchNorm2d(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)]
+ sequence += [nn.Sigmoid()]
+
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
+ else:
+ return self.model(input)
+
+# Instance Normalization layer from
+# https://github.com/darkstar112358/fast-neural-style
+
+class InstanceNormalization(torch.nn.Module):
+ """InstanceNormalization
+ Improves convergence of neural-style.
+ ref: https://arxiv.org/pdf/1607.08022.pdf
+ """
+
+ def __init__(self, dim, eps=1e-5):
+ super(InstanceNormalization, self).__init__()
+ self.weight = nn.Parameter(torch.FloatTensor(dim))
+ self.bias = nn.Parameter(torch.FloatTensor(dim))
+ self.eps = eps
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ self.weight.data.uniform_()
+ self.bias.data.zero_()
+
+ def forward(self, x):
+ n = x.size(2) * x.size(3)
+ t = x.view(x.size(0), x.size(1), n)
+ mean = torch.mean(t, 2).unsqueeze(2).expand_as(x)
+ # Calculate the biased var. torch.var returns unbiased var
+ var = torch.var(t, 2).unsqueeze(2).expand_as(x) * ((n - 1) / float(n))
+ scale_broadcast = self.weight.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+ scale_broadcast = scale_broadcast.expand_as(x)
+ shift_broadcast = self.bias.unsqueeze(1).unsqueeze(1).unsqueeze(0)
+ shift_broadcast = shift_broadcast.expand_as(x)
+ out = (x - mean) / torch.sqrt(var + self.eps)
+ out = out * scale_broadcast + shift_broadcast
+ return out
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
new file mode 100644
index 0000000..1d89b29
--- /dev/null
+++ b/models/pix2pix_model.py
@@ -0,0 +1,147 @@
+import numpy as np
+import torch
+import os
+from collections import OrderedDict
+from pdb import set_trace as st
+from torch.autograd import Variable
+import util.util as util
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+
+
+class Pix2PixModel(BaseModel):
+ def name(self):
+ return 'Pix2PixModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+ self.isTrain = opt.isTrain
+ # define tensors
+ self.input_A = self.Tensor(opt.batchSize, opt.input_nc,
+ opt.fineSize, opt.fineSize)
+ self.input_B = self.Tensor(opt.batchSize, opt.output_nc,
+ opt.fineSize, opt.fineSize)
+
+ # load/define networks
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
+ opt.which_model_netG, opt.norm, self.gpu_ids)
+ if self.isTrain:
+ use_sigmoid = opt.no_lsgan
+ self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
+ opt.which_model_netD,
+ opt.n_layers_D, use_sigmoid, self.gpu_ids)
+ if not self.isTrain or opt.continue_train:
+ self.load_network(self.netG, 'G', opt.which_epoch)
+ if self.isTrain:
+ self.load_network(self.netD, 'D', opt.which_epoch)
+
+ if self.isTrain:
+ self.fake_AB_pool = ImagePool(opt.pool_size)
+ self.old_lr = opt.lr
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
+ self.criterionL1 = torch.nn.L1Loss()
+
+ # initialize optimizers
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
+ lr=opt.lr, betas=(opt.beta1, 0.999))
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG)
+ networks.print_network(self.netD)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction is 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ input_B = input['B' if AtoB else 'A']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.input_B.resize_(input_B.size()).copy_(input_B)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ self.real_A = Variable(self.input_A)
+ self.fake_B = self.netG.forward(self.real_A)
+ self.real_B = Variable(self.input_B)
+
+ # no backprop gradients
+ def test(self):
+ self.real_A = Variable(self.input_A, volatile=True)
+ self.fake_B = self.netG.forward(self.real_A)
+ self.real_B = Variable(self.input_B, volatile=True)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def backward_D(self):
+ # Fake
+ # stop backprop to the generator by detaching fake_B
+ fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
+ self.pred_fake = self.netD.forward(fake_AB.detach())
+ self.loss_D_fake = self.criterionGAN(self.pred_fake, False)
+
+ # Real
+ real_AB = torch.cat((self.real_A, self.real_B), 1)#.detach()
+ self.pred_real = self.netD.forward(real_AB)
+ self.loss_D_real = self.criterionGAN(self.pred_real, True)
+
+ # Combined loss
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+
+ self.loss_D.backward()
+
+ def backward_G(self):
+ # First, G(A) should fake the discriminator
+ fake_AB = torch.cat((self.real_A, self.fake_B), 1)
+ pred_fake = self.netD.forward(fake_AB)
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+
+ # Second, G(A) = B
+ self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A
+
+ self.loss_G = self.loss_G_GAN + self.loss_G_L1
+
+ self.loss_G.backward()
+
+ def optimize_parameters(self):
+ self.forward()
+
+ self.optimizer_D.zero_grad()
+ self.backward_D()
+ self.optimizer_D.step()
+
+ self.optimizer_G.zero_grad()
+ self.backward_G()
+ self.optimizer_G.step()
+
+ def get_current_errors(self):
+ return OrderedDict([('G_GAN', self.loss_G_GAN.data[0]),
+ ('G_L1', self.loss_G_L1.data[0]),
+ ('D_real', self.loss_D_real.data[0]),
+ ('D_fake', self.loss_D_fake.data[0])
+ ])
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ real_B = util.tensor2im(self.real_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
+
+ def save(self, label):
+ use_gpu = self.gpu_ids is not None
+ self.save_network(self.netG, 'G', label, use_gpu)
+ self.save_network(self.netD, 'D', label, use_gpu)
+
+ def update_learning_rate(self):
+ lrd = self.opt.lr / self.opt.niter_decay
+ lr = self.old_lr - lrd
+ for param_group in self.optimizer_D.param_groups:
+ param_group['lr'] = lr
+ for param_group in self.optimizer_G.param_groups:
+ param_group['lr'] = lr
+ print('update learning rate: %f -> %f' % (self.old_lr, lr))
+ self.old_lr = lr