diff options
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 112 |
1 files changed, 90 insertions, 22 deletions
diff --git a/models/networks.py b/models/networks.py index 6cf4169..2df58fe 100644 --- a/models/networks.py +++ b/models/networks.py @@ -3,21 +3,74 @@ import torch.nn as nn from torch.nn import init import functools from torch.autograd import Variable +from torch.optim import lr_scheduler import numpy as np ############################################################################### # Functions ############################################################################### -def weights_init(m): + +def weights_init_normal(m): + classname = m.__class__.__name__ + # print(classname) + if classname.find('Conv') != -1: + init.uniform(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + init.uniform(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + # print(classname) + if classname.find('Conv') != -1: + init.xavier_normal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.xavier_normal(m.weight.data, gain=1) + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + # print(classname) + if classname.find('Conv') != -1: + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm2d') != -1: + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): classname = m.__class__.__name__ + print(classname) if classname.find('Conv') != -1: - m.weight.data.normal_(0.0, 0.02) - if hasattr(m.bias, 'data'): - m.bias.data.fill_(0) + init.orthogonal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.orthogonal(m.weight.data, gain=1) elif classname.find('BatchNorm2d') != -1: - m.weight.data.normal_(1.0, 0.02) - m.bias.data.fill_(0) + init.uniform(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) def get_norm_layer(norm_type='instance'): @@ -25,12 +78,29 @@ def get_norm_layer(norm_type='instance'): norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == 'instance': norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif layer_type == 'none': + norm_layer = None else: raise NotImplementedError('normalization layer [%s] is not found' % norm_type) return norm_layer -def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]): +def get_scheduler(optimizer, opt): + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]): netG = None use_gpu = len(gpu_ids) > 0 norm_layer = get_norm_layer(norm_type=norm) @@ -50,12 +120,12 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: netG.cuda(device_id=gpu_ids[0]) - netG.apply(weights_init) + init_weights(netG, init_type=init_type) return netG def define_D(input_nc, ndf, which_model_netD, - n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]): + n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]): netD = None use_gpu = len(gpu_ids) > 0 norm_layer = get_norm_layer(norm_type=norm) @@ -71,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD, which_model_netD) if use_gpu: netD.cuda(device_id=gpu_ids[0]) - netD.apply(weights_init) + init_weights(netD, init_type=init_type) return netD @@ -238,17 +308,14 @@ class UnetGenerator(nn.Module): super(UnetGenerator, self).__init__() self.gpu_ids = gpu_ids - # currently support only input_nc == output_nc - assert(input_nc == output_nc) - # construct unet structure - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True) + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) for i in range(num_downs - 5): - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout) - unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) self.model = unet_block @@ -263,7 +330,7 @@ class UnetGenerator(nn.Module): # X -------------------identity---------------------- X # |-- downsampling -- |submodule| -- upsampling --| class UnetSkipConnectionBlock(nn.Module): - def __init__(self, outer_nc, inner_nc, + def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost @@ -271,8 +338,9 @@ class UnetSkipConnectionBlock(nn.Module): use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d - - downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) downrelu = nn.LeakyReLU(0.2, True) downnorm = norm_layer(inner_nc) |
