| author | junyanz <junyanz@berkeley.edu> | 2017-10-06 10:46:43 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-10-06 10:46:43 -0700 |
| commit | 7800d516596f1a25986b458cddf8b8785bcc7df8 (patch) | |
| tree | 56d57350e7104393f939ec7cc2e07c96840aaa27 /models | |
| parent | e986144cee13a921fd3ad68d564f820e8f7dd3b0 (diff) | |
support nc=1, add new learning rate policy and new initialization
Diffstat (limited to 'models')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | models/base_model.py | 5 |
| -rw-r--r-- | models/cycle_gan_model.py | 20 |
| -rw-r--r-- | models/networks.py | 112 |
| -rw-r--r-- | models/pix2pix_model.py | 20 |
| -rw-r--r-- | models/test_model.py | 1 |
5 files changed, 110 insertions, 48 deletions
diff --git a/models/base_model.py b/models/base_model.py
index 36ceb43..55da1ca 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -53,4 +53,7 @@ class BaseModel():
         network.load_state_dict(torch.load(save_path))
 
     def update_learning_rate():
-        pass
+        for scheduler in self.schedulers:
+            scheduler.step()
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate = %.7f' % lr)
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b3c52c7..c6b336c 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -61,6 +61,13 @@ class CycleGANModel(BaseModel):
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers = []
+            self.schedulers = []
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D_A)
+            self.optimizers.append(self.optimizer_D_B)
+            for optimizer in self.optimizers:
+                self.schedulers.append(networks.get_scheduler(optimizer, opt))
 
         print('---------- Networks initialized -------------')
         networks.print_network(self.netG_A)
@@ -204,16 +211,3 @@ class CycleGANModel(BaseModel):
         self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
         self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
         self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
-
-    def update_learning_rate(self):
-        lrd = self.opt.lr / self.opt.niter_decay
-        lr = self.old_lr - lrd
-        for param_group in self.optimizer_D_A.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_D_B.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_G.param_groups:
-            param_group['lr'] = lr
-
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
-        self.old_lr = lr
diff --git a/models/networks.py b/models/networks.py
index 6cf4169..2df58fe 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -3,21 +3,74 @@ import torch.nn as nn
 from torch.nn import init
 import functools
 from torch.autograd import Variable
+from torch.optim import lr_scheduler
 import numpy as np
 ###############################################################################
 # Functions
 ###############################################################################
 
 
-def weights_init(m):
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.uniform(m.weight.data, 0.0, 0.02)
+    elif classname.find('Linear') != -1:
+        init.uniform(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_xavier(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.xavier_normal(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.xavier_normal(m.weight.data, gain=1)
+    elif classname.find('BatchNorm2d') != -1:
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('BatchNorm2d') != -1:
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
     classname = m.__class__.__name__
+    print(classname)
     if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
-        if hasattr(m.bias, 'data'):
-            m.bias.data.fill_(0)
+        init.orthogonal(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.orthogonal(m.weight.data, gain=1)
     elif classname.find('BatchNorm2d') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='normal'):
+    print('initialization method [%s]' % init_type)
+    if init_type == 'normal':
+        net.apply(weights_init_normal)
+    elif init_type == 'xavier':
+        net.apply(weights_init_xavier)
+    elif init_type == 'kaiming':
+        net.apply(weights_init_kaiming)
+    elif init_type == 'orthogonal':
+        net.apply(weights_init_orthogonal)
+    else:
+        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
 
 
 def get_norm_layer(norm_type='instance'):
@@ -25,12 +78,29 @@ def get_norm_layer(norm_type='instance'):
         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
     elif norm_type == 'instance':
         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+    elif layer_type == 'none':
+        norm_layer = None
     else:
         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
     return norm_layer
 
 
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
+def get_scheduler(optimizer, opt):
+    if opt.lr_policy == 'lambda':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    else:
+        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+    return scheduler
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
     netG = None
     use_gpu = len(gpu_ids) > 0
     norm_layer = get_norm_layer(norm_type=norm)
@@ -50,12 +120,12 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
         raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
         netG.cuda(device_id=gpu_ids[0])
-    netG.apply(weights_init)
+    init_weights(netG, init_type=init_type)
     return netG
 
 
 def define_D(input_nc, ndf, which_model_netD,
-             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
+             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
     norm_layer = get_norm_layer(norm_type=norm)
@@ -71,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD,
                                   which_model_netD)
     if use_gpu:
         netD.cuda(device_id=gpu_ids[0])
-    netD.apply(weights_init)
+    init_weights(netD, init_type=init_type)
    return netD
 
 
@@ -238,17 +308,14 @@ class UnetGenerator(nn.Module):
         super(UnetGenerator, self).__init__()
         self.gpu_ids = gpu_ids
 
-        # currently support only input_nc == output_nc
-        assert(input_nc == output_nc)
-
         # construct unet structure
-        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
 
         self.model = unet_block
 
@@ -263,7 +330,7 @@ class UnetGenerator(nn.Module):
 # X -------------------identity---------------------- X
 #   |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
-    def __init__(self, outer_nc, inner_nc,
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
@@ -271,8 +338,9 @@ class UnetSkipConnectionBlock(nn.Module):
             use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
-
-        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1, bias=use_bias)
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index a524f2c..18ba53f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -24,12 +24,12 @@ class Pix2PixModel(BaseModel):
 
         # load/define networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
-                                      opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
+                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                           opt.which_model_netD,
-                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             self.load_network(self.netG, 'G', opt.which_epoch)
             if self.isTrain:
@@ -43,10 +43,16 @@ class Pix2PixModel(BaseModel):
             self.criterionL1 = torch.nn.L1Loss()
 
             # initialize optimizers
+            self.schedulers = []
+            self.optimizers = []
             self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+            for optimizer in self.optimizers:
+                self.schedulers.append(networks.get_scheduler(optimizer, opt))
 
         print('---------- Networks initialized -------------')
         networks.print_network(self.netG)
@@ -134,13 +140,3 @@ class Pix2PixModel(BaseModel):
     def save(self, label):
         self.save_network(self.netG, 'G', label, self.gpu_ids)
         self.save_network(self.netD, 'D', label, self.gpu_ids)
-
-    def update_learning_rate(self):
-        lrd = self.opt.lr / self.opt.niter_decay
-        lr = self.old_lr - lrd
-        for param_group in self.optimizer_D.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_G.param_groups:
-            param_group['lr'] = lr
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
-        self.old_lr = lr
diff --git a/models/test_model.py b/models/test_model.py
index 03aef65..4af1fe1 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -17,6 +17,7 @@ class TestModel(BaseModel):
         self.netG = networks.define_G(opt.input_nc, opt.output_nc,
                                       opt.ngf, opt.which_model_netG,
                                       opt.norm, not opt.no_dropout,
+                                      opt.init_type,
                                       self.gpu_ids)
         which_epoch = opt.which_epoch
         self.load_network(self.netG, 'G', which_epoch)
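For readers skimming the change: the per-model `update_learning_rate` implementations in `cycle_gan_model.py` and `pix2pix_model.py` are replaced by `torch.optim.lr_scheduler` objects built by the new `networks.get_scheduler` and stepped once per epoch from `BaseModel.update_learning_rate`, while weight initialization is now routed through `networks.init_weights`. The snippet below is a minimal sketch of how those two new entry points are driven. It assumes the 2017-era PyTorch API this commit targets and that `models/networks.py` from this tree is importable from the repository root; the `argparse.Namespace` and the toy network are hypothetical stand-ins for the parsed training options and a real generator.

```python
# Minimal sketch (assumes the 2017-era PyTorch API this commit targets and that
# models/networks.py from this tree is importable; the Namespace below is a
# hypothetical stand-in for the parsed training options).
import argparse

import torch
import torch.nn as nn

from models import networks

# Toy generator-like module, just something with Conv layers to initialize.
net = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1),
                    nn.ReLU(True),
                    nn.Conv2d(8, 1, 3, padding=1))

# New initialization entry point: 'normal' | 'xavier' | 'kaiming' | 'orthogonal'.
networks.init_weights(net, init_type='kaiming')

# Only the fields get_scheduler() reads in this diff; toy values so the decay
# under the 'lambda' policy is visible within a few epochs.
opt = argparse.Namespace(lr_policy='lambda', niter=2, niter_decay=3,
                         lr_decay_iters=50)

optimizer = torch.optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
scheduler = networks.get_scheduler(optimizer, opt)

# This mirrors what BaseModel.update_learning_rate() now does once per epoch.
for epoch in range(5):
    scheduler.step()
    print('learning rate = %.7f' % optimizer.param_groups[0]['lr'])
```

With `--lr_policy lambda` the rate is held constant for `niter` epochs and then decays linearly over `niter_decay` epochs, which matches the old hand-rolled decay; `step` and `plateau` are the new alternatives this commit adds.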
