Diffstat (limited to 'models/networks.py')
-rw-r--r--  models/networks.py  112
1 file changed, 90 insertions(+), 22 deletions(-)
diff --git a/models/networks.py b/models/networks.py
index 6cf4169..2df58fe 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -3,21 +3,74 @@ import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
+from torch.optim import lr_scheduler
import numpy as np
###############################################################################
# Functions
###############################################################################
-def weights_init(m):
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.normal(m.weight.data, 0.0, 0.02)
+    elif classname.find('Linear') != -1:
+        init.normal(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_xavier(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.xavier_normal(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.xavier_normal(m.weight.data, gain=1)
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
    classname = m.__class__.__name__
+    print(classname)
    if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
-        if hasattr(m.bias, 'data'):
-            m.bias.data.fill_(0)
+        init.orthogonal(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.orthogonal(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+        init.normal(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='normal'):
+    print('initialization method [%s]' % init_type)
+    if init_type == 'normal':
+        net.apply(weights_init_normal)
+    elif init_type == 'xavier':
+        net.apply(weights_init_xavier)
+    elif init_type == 'kaiming':
+        net.apply(weights_init_kaiming)
+    elif init_type == 'orthogonal':
+        net.apply(weights_init_orthogonal)
+    else:
+        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
def get_norm_layer(norm_type='instance'):
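The four initializers above are selected by name in init_weights and applied with nn.Module.apply, which visits every submodule, so one call re-initializes all Conv/Linear/BatchNorm2d layers. A minimal usage sketch (the toy network and the import path are illustrative assumptions, not part of this commit):

import torch.nn as nn
from models.networks import init_weights  # assumes the repository root is on PYTHONPATH

# Toy module mix; apply() dispatches on the class name of each child module.
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(True),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)
init_weights(net, init_type='kaiming')  # or 'normal', 'xavier', 'orthogonal'
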
@@ -25,12 +78,29 @@ def get_norm_layer(norm_type='instance'):
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+    elif norm_type == 'none':
+        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
-def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
+def get_scheduler(optimizer, opt):
+    if opt.lr_policy == 'lambda':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    else:
+        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+    return scheduler
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
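get_scheduler above wraps torch.optim.lr_scheduler and reads opt.lr_policy together with the option fields used in each branch. A rough sketch of driving it, with argparse.Namespace standing in for the real options object (an assumption for illustration):

import torch
from argparse import Namespace
from models.networks import get_scheduler  # assumed import path

opt = Namespace(lr_policy='lambda', niter=100, niter_decay=100)  # field names as read above
optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.0002)
scheduler = get_scheduler(optimizer, opt)

for epoch in range(opt.niter + opt.niter_decay):
    # ... train one epoch ...
    scheduler.step()  # flat LR for the first niter epochs, then linear decay toward 0
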
@@ -50,12 +120,12 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
        netG.cuda(device_id=gpu_ids[0])
-    netG.apply(weights_init)
+    init_weights(netG, init_type=init_type)
    return netG
def define_D(input_nc, ndf, which_model_netD,
-             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
+             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
@@ -71,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD,
                                  which_model_netD)
    if use_gpu:
        netD.cuda(device_id=gpu_ids[0])
-    netD.apply(weights_init)
+    init_weights(netD, init_type=init_type)
    return netD
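Both factories now forward init_type to init_weights instead of hard-coding the old weights_init. A hedged sketch of the updated call sites; the model-name strings ('unet_256', 'basic') are assumed keys handled elsewhere in this file and not shown in this hunk:

from models.networks import define_G, define_D  # assumed import path

# Hypothetical pix2pix-style setup: the discriminator sees input and output
# images concatenated along the channel axis, hence input_nc=3 + 3.
netG = define_G(input_nc=3, output_nc=3, ngf=64, which_model_netG='unet_256',
                norm='batch', use_dropout=True, init_type='xavier', gpu_ids=[])
netD = define_D(input_nc=3 + 3, ndf=64, which_model_netD='basic', n_layers_D=3,
                norm='batch', use_sigmoid=True, init_type='xavier', gpu_ids=[])
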
@@ -238,17 +308,14 @@ class UnetGenerator(nn.Module):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids
-        # currently support only input_nc == output_nc
-        assert(input_nc == output_nc)
-
        # construct unet structure
-        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
        self.model = unet_block
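With input_nc passed only to the outermost block, the old input_nc == output_nc assertion is no longer needed, so the generator can, for example, map a 1-channel edge map to a 3-channel image. A sketch (constructor argument names are assumed to match the signature, which this hunk does not show):

import torch
import torch.nn as nn
from torch.autograd import Variable
from models.networks import UnetGenerator  # assumed import path

# num_downs=8 halves a 256x256 input down to a 1x1 bottleneck.
netG = UnetGenerator(input_nc=1, output_nc=3, num_downs=8, ngf=64,
                     norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[])
fake = netG(Variable(torch.randn(1, 1, 256, 256)))  # -> size (1, 3, 256, 256)
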
@@ -263,7 +330,7 @@ class UnetGenerator(nn.Module):
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
-    def __init__(self, outer_nc, inner_nc,
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
@@ -271,8 +338,9 @@ class UnetSkipConnectionBlock(nn.Module):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
-
-        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
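The input_nc-to-outer_nc fallback keeps interior blocks exactly as before, while an outermost block can now map a different number of input channels onto outer_nc output channels. A rough two-level example, assuming the usual skip-connection forward (non-outermost blocks concatenate their input with their output along the channel axis):

import torch
import torch.nn as nn
from torch.autograd import Variable
from models.networks import UnetSkipConnectionBlock  # assumed import path

inner = UnetSkipConnectionBlock(64, 128, input_nc=None, submodule=None,
                                norm_layer=nn.BatchNorm2d, innermost=True)  # 64 -> 128 -> 64, skip cat -> 128
outer = UnetSkipConnectionBlock(3, 64, input_nc=1, submodule=inner,
                                outermost=True, norm_layer=nn.BatchNorm2d)  # 1 channel in, 3 channels out
y = outer(Variable(torch.randn(1, 1, 64, 64)))  # -> size (1, 3, 64, 64)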