Diffstat (limited to 'models')
-rw-r--r--  models/base_model.py       5
-rw-r--r--  models/cycle_gan_model.py  20
-rw-r--r--  models/networks.py         112
-rw-r--r--  models/pix2pix_model.py    20
-rw-r--r--  models/test_model.py       1
5 files changed, 110 insertions, 48 deletions
diff --git a/models/base_model.py b/models/base_model.py
index 36ceb43..55da1ca 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -53,4 +53,7 @@ class BaseModel():
network.load_state_dict(torch.load(save_path))
def update_learning_rate(self):
- pass
+ for scheduler in self.schedulers:
+ scheduler.step()
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
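With the scheduler-based update_learning_rate above, the training script only needs one call per epoch to step every registered scheduler at once. A minimal sketch of that call site, assuming the repo's usual model/dataset setup; the loop bounds and variable names here are illustrative rather than copied from train.py:

    # model comes from the repo's model factory and dataset from its data loader;
    # both are assumptions here, not part of this diff.
    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        for data in dataset:
            model.set_input(data)
            model.optimize_parameters()
        # one call steps every scheduler registered in self.schedulers
        model.update_learning_rate()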
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b3c52c7..c6b336c 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -61,6 +61,13 @@ class CycleGANModel(BaseModel):
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = []
+ self.schedulers = []
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D_A)
+ self.optimizers.append(self.optimizer_D_B)
+ for optimizer in self.optimizers:
+ self.schedulers.append(networks.get_scheduler(optimizer, opt))
print('---------- Networks initialized -------------')
networks.print_network(self.netG_A)
@@ -204,16 +211,3 @@ class CycleGANModel(BaseModel):
self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
-
- def update_learning_rate(self):
- lrd = self.opt.lr / self.opt.niter_decay
- lr = self.old_lr - lrd
- for param_group in self.optimizer_D_A.param_groups:
- param_group['lr'] = lr
- for param_group in self.optimizer_D_B.param_groups:
- param_group['lr'] = lr
- for param_group in self.optimizer_G.param_groups:
- param_group['lr'] = lr
-
- print('update learning rate: %f -> %f' % (self.old_lr, lr))
- self.old_lr = lr
diff --git a/models/networks.py b/models/networks.py
index 6cf4169..2df58fe 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -3,21 +3,74 @@ import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
+from torch.optim import lr_scheduler
import numpy as np
###############################################################################
# Functions
###############################################################################
-def weights_init(m):
+
+def weights_init_normal(m):
+ classname = m.__class__.__name__
+ # print(classname)
+ if classname.find('Conv') != -1:
+ init.uniform(m.weight.data, 0.0, 0.02)
+ elif classname.find('Linear') != -1:
+ init.uniform(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm2d') != -1:
+ init.uniform(m.weight.data, 1.0, 0.02)
+ init.constant(m.bias.data, 0.0)
+
+
+def weights_init_xavier(m):
+ classname = m.__class__.__name__
+ # print(classname)
+ if classname.find('Conv') != -1:
+ init.xavier_normal(m.weight.data, gain=1)
+ elif classname.find('Linear') != -1:
+ init.xavier_normal(m.weight.data, gain=1)
+ elif classname.find('BatchNorm2d') != -1:
+ init.uniform(m.weight.data, 1.0, 0.02)
+ init.constant(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m):
+ classname = m.__class__.__name__
+ # print(classname)
+ if classname.find('Conv') != -1:
+ init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+ elif classname.find('Linear') != -1:
+ init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+ elif classname.find('BatchNorm2d') != -1:
+ init.uniform(m.weight.data, 1.0, 0.02)
+ init.constant(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
classname = m.__class__.__name__
+ print(classname)
if classname.find('Conv') != -1:
- m.weight.data.normal_(0.0, 0.02)
- if hasattr(m.bias, 'data'):
- m.bias.data.fill_(0)
+ init.orthogonal(m.weight.data, gain=1)
+ elif classname.find('Linear') != -1:
+ init.orthogonal(m.weight.data, gain=1)
elif classname.find('BatchNorm2d') != -1:
- m.weight.data.normal_(1.0, 0.02)
- m.bias.data.fill_(0)
+ init.uniform(m.weight.data, 1.0, 0.02)
+ init.constant(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='normal'):
+ print('initialization method [%s]' % init_type)
+ if init_type == 'normal':
+ net.apply(weights_init_normal)
+ elif init_type == 'xavier':
+ net.apply(weights_init_xavier)
+ elif init_type == 'kaiming':
+ net.apply(weights_init_kaiming)
+ elif init_type == 'orthogonal':
+ net.apply(weights_init_orthogonal)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
def get_norm_layer(norm_type='instance'):
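The init_weights dispatcher above applies one of the four initializers to every submodule via net.apply, matching modules by class name. A small sketch of how it could be exercised on a throwaway network (the Sequential below is purely illustrative; real callers go through define_G/define_D):

    import torch.nn as nn
    from models.networks import init_weights  # module path per the diffstat above

    # toy convolutional stack, only meant to hit the 'Conv' branch of the initializers
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(True),
        nn.Conv2d(8, 3, kernel_size=3, padding=1),
    )
    init_weights(net, init_type='xavier')  # prints: initialization method [xavier]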
@@ -25,12 +78,29 @@ def get_norm_layer(norm_type='instance'):
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+ elif norm_type == 'none':
+ norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
+def get_scheduler(optimizer, opt):
+ if opt.lr_policy == 'lambda':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ else:
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+ return scheduler
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
netG = None
use_gpu = len(gpu_ids) > 0
norm_layer = get_norm_layer(norm_type=norm)
@@ -50,12 +120,12 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
netG.cuda(device_id=gpu_ids[0])
- netG.apply(weights_init)
+ init_weights(netG, init_type=init_type)
return netG
def define_D(input_nc, ndf, which_model_netD,
- n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
+ n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
netD = None
use_gpu = len(gpu_ids) > 0
norm_layer = get_norm_layer(norm_type=norm)
@@ -71,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD,
which_model_netD)
if use_gpu:
netD.cuda(device_id=gpu_ids[0])
- netD.apply(weights_init)
+ init_weights(netD, init_type=init_type)
return netD
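For the 'lambda' policy, lambda_rule keeps the multiplier at 1.0 for the first opt.niter epochs and then decays it linearly toward zero over the next opt.niter_decay epochs; LambdaLR multiplies each optimizer's base lr by this factor. A worked sketch with hypothetical values niter=100 and niter_decay=100:

    niter, niter_decay = 100, 100   # hypothetical stand-ins for opt.niter / opt.niter_decay

    def lambda_rule(epoch):
        return 1.0 - max(0, epoch - niter) / float(niter_decay + 1)

    print(lambda_rule(50))    # 1.0     -> lr unchanged during the first niter epochs
    print(lambda_rule(150))   # ~0.505  -> roughly half the base lr midway through decay
    print(lambda_rule(200))   # ~0.0099 -> close to zero at the end of training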
@@ -238,17 +308,14 @@ class UnetGenerator(nn.Module):
super(UnetGenerator, self).__init__()
self.gpu_ids = gpu_ids
- # currently support only input_nc == output_nc
- assert(input_nc == output_nc)
-
# construct unet structure
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
for i in range(num_downs - 5):
- unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
- unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
- unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
- unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
- unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
self.model = unet_block
@@ -263,7 +330,7 @@ class UnetGenerator(nn.Module):
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
- def __init__(self, outer_nc, inner_nc,
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
@@ -271,8 +338,9 @@ class UnetSkipConnectionBlock(nn.Module):
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
-
- downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
+ if input_nc is None:
+ input_nc = outer_nc
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
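Because the U-Net blocks now take an explicit input_nc and the assert(input_nc == output_nc) is gone, the generator can map between different channel counts. A sketch of building such a generator through define_G; 'unet_256' is assumed here to be one of the recognized which_model_netG options:

    # e.g. a grayscale-to-RGB U-Net, which the old assert would have rejected
    netG = define_G(input_nc=1, output_nc=3, ngf=64,
                    which_model_netG='unet_256', norm='batch',
                    use_dropout=False, init_type='normal', gpu_ids=[])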
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index a524f2c..18ba53f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -24,12 +24,12 @@ class Pix2PixModel(BaseModel):
# load/define networks
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
- opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
+ opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
if self.isTrain:
use_sigmoid = opt.no_lsgan
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
opt.which_model_netD,
- opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+ opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
if not self.isTrain or opt.continue_train:
self.load_network(self.netG, 'G', opt.which_epoch)
if self.isTrain:
@@ -43,10 +43,16 @@ class Pix2PixModel(BaseModel):
self.criterionL1 = torch.nn.L1Loss()
# initialize optimizers
+ self.schedulers = []
+ self.optimizers = []
self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D)
+ for optimizer in self.optimizers:
+ self.schedulers.append(networks.get_scheduler(optimizer, opt))
print('---------- Networks initialized -------------')
networks.print_network(self.netG)
@@ -134,13 +140,3 @@ class Pix2PixModel(BaseModel):
def save(self, label):
self.save_network(self.netG, 'G', label, self.gpu_ids)
self.save_network(self.netD, 'D', label, self.gpu_ids)
-
- def update_learning_rate(self):
- lrd = self.opt.lr / self.opt.niter_decay
- lr = self.old_lr - lrd
- for param_group in self.optimizer_D.param_groups:
- param_group['lr'] = lr
- for param_group in self.optimizer_G.param_groups:
- param_group['lr'] = lr
- print('update learning rate: %f -> %f' % (self.old_lr, lr))
- self.old_lr = lr
diff --git a/models/test_model.py b/models/test_model.py
index 03aef65..4af1fe1 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -17,6 +17,7 @@ class TestModel(BaseModel):
self.netG = networks.define_G(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG,
opt.norm, not opt.no_dropout,
+ opt.init_type,
self.gpu_ids)
which_epoch = opt.which_epoch
self.load_network(self.netG, 'G', which_epoch)