| field | value | date |
|---|---|---|
| author | junyanz <junyanz@berkeley.edu> | 2017-10-06 10:46:43 -0700 |
| committer | junyanz <junyanz@berkeley.edu> | 2017-10-06 10:46:43 -0700 |
| commit | 7800d516596f1a25986b458cddf8b8785bcc7df8 (patch) | |
| tree | 56d57350e7104393f939ec7cc2e07c96840aaa27 | |
| parent | e986144cee13a921fd3ad68d564f820e8f7dd3b0 (diff) | |
support nc=1, add new learning rate policy and new initialization
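The nc=1 support added here keeps loading every image as RGB and then collapses it to a single channel inside the dataset classes, using the ITU-R BT.601 luma weights (0.299, 0.587, 0.114) whenever the effective input_nc or output_nc is 1. A minimal sketch of that conversion on its own, assuming a 3 x H x W float tensor like the one produced by the existing transform pipeline:

```python
import torch

def rgb_to_gray(img):
    # img: 3 x H x W tensor in the same layout the datasets' transform returns
    # BT.601 luma weights, matching the conversion added to the dataset classes
    gray = img[0, ...] * 0.299 + img[1, ...] * 0.587 + img[2, ...] * 0.114
    return gray.unsqueeze(0)  # 1 x H x W

x = torch.rand(3, 256, 256)
print(rgb_to_gray(x).shape)  # torch.Size([1, 256, 256])
```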
| mode | file | changed |
|---|---|---|
| -rw-r--r-- | data/aligned_dataset.py | 15 |
| -rw-r--r-- | data/single_dataset.py | 12 |
| -rw-r--r-- | data/unaligned_dataset.py | 20 |
| -rw-r--r-- | models/base_model.py | 5 |
| -rw-r--r-- | models/cycle_gan_model.py | 20 |
| -rw-r--r-- | models/networks.py | 112 |
| -rw-r--r-- | models/pix2pix_model.py | 20 |
| -rw-r--r-- | models/test_model.py | 1 |
| -rw-r--r-- | options/base_options.py | 1 |
| -rw-r--r-- | options/train_options.py | 3 |
| -rw-r--r-- | util/util.py | 2 |
| -rw-r--r-- | util/visualizer.py | 3 |
12 files changed, 157 insertions, 57 deletions
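Most of the insertions implement the new learning rate machinery: each model now keeps a list of optimizers plus a matching list of schedulers created by networks.get_scheduler, and BaseModel.update_learning_rate simply steps every scheduler. Under the default 'lambda' policy the rate stays constant for the first opt.niter epochs and then decays linearly toward zero over the following opt.niter_decay epochs. A rough, self-contained sketch of that schedule; niter = 100, niter_decay = 100 and lr = 0.0002 are assumed values for illustration:

```python
import torch
from torch.optim import lr_scheduler

niter, niter_decay = 100, 100  # assumed stand-ins for opt.niter / opt.niter_decay

def lambda_rule(epoch):
    # multiplier applied to the initial lr: 1.0 for the first `niter` epochs,
    # then a linear ramp down toward zero over `niter_decay` epochs
    return 1.0 - max(0, epoch - niter) / float(niter_decay + 1)

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=0.0002)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

for epoch in range(niter + niter_decay):
    # ... run one training epoch, calling optimizer.step() as usual ...
    scheduler.step()  # what the new BaseModel.update_learning_rate does per scheduler
```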
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py
index 0f45c40..bccd6fc 100644
--- a/data/aligned_dataset.py
+++ b/data/aligned_dataset.py
@@ -40,12 +40,27 @@ class AlignedDataset(BaseDataset):
         B = AB[:, h_offset:h_offset + self.opt.fineSize,
                w + w_offset:w + w_offset + self.opt.fineSize]

+        if self.opt.which_direction == 'BtoA':
+            input_nc = self.opt.output_nc
+            output_nc = self.opt.input_nc
+        else:
+            input_nc = self.opt.input_nc
+            output_nc = self.opt.output_nc
+
         if (not self.opt.no_flip) and random.random() < 0.5:
             idx = [i for i in range(A.size(2) - 1, -1, -1)]
             idx = torch.LongTensor(idx)
             A = A.index_select(2, idx)
             B = B.index_select(2, idx)

+        if input_nc == 1:  # RGB to gray
+            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+            A = tmp.unsqueeze(0)
+
+        if output_nc == 1:  # RGB to gray
+            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+            B = tmp.unsqueeze(0)
+
         return {'A': A, 'B': B,
                 'A_paths': AB_path, 'B_paths': AB_path}
diff --git a/data/single_dataset.py b/data/single_dataset.py
index faf416a..f8b4f1d 100644
--- a/data/single_dataset.py
+++ b/data/single_dataset.py
@@ -19,12 +19,18 @@ class SingleDataset(BaseDataset):

     def __getitem__(self, index):
         A_path = self.A_paths[index]
-
         A_img = Image.open(A_path).convert('RGB')
+        A = self.transform(A_img)
+        if self.opt.which_direction == 'BtoA':
+            input_nc = self.opt.output_nc
+        else:
+            input_nc = self.opt.input_nc

-        A_img = self.transform(A_img)
+        if input_nc == 1:  # RGB to gray
+            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+            A = tmp.unsqueeze(0)

-        return {'A': A_img, 'A_paths': A_path}
+        return {'A': A, 'A_paths': A_path}

     def __len__(self):
         return len(self.A_paths)
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
index d31eb05..c5e5460 100644
--- a/data/unaligned_dataset.py
+++ b/data/unaligned_dataset.py
@@ -4,7 +4,6 @@ from data.base_dataset import BaseDataset, get_transform
 from data.image_folder import make_dataset
 from PIL import Image
 import PIL
-from pdb import set_trace as st
 import random

 class UnalignedDataset(BaseDataset):
@@ -32,10 +31,23 @@ class UnalignedDataset(BaseDataset):
         A_img = Image.open(A_path).convert('RGB')
         B_img = Image.open(B_path).convert('RGB')

-        A_img = self.transform(A_img)
-        B_img = self.transform(B_img)
+        A = self.transform(A_img)
+        B = self.transform(B_img)
+        if self.opt.which_direction == 'BtoA':
+            input_nc = self.opt.output_nc
+            output_nc = self.opt.input_nc
+        else:
+            input_nc = self.opt.input_nc
+            output_nc = self.opt.output_nc

-        return {'A': A_img, 'B': B_img,
+        if input_nc == 1:  # RGB to gray
+            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
+            A = tmp.unsqueeze(0)
+
+        if output_nc == 1:  # RGB to gray
+            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
+            B = tmp.unsqueeze(0)
+        return {'A': A, 'B': B,
                 'A_paths': A_path, 'B_paths': B_path}

     def __len__(self):
diff --git a/models/base_model.py b/models/base_model.py
index 36ceb43..55da1ca 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -53,4 +53,7 @@ class BaseModel():
         network.load_state_dict(torch.load(save_path))

     def update_learning_rate():
-        pass
+        for scheduler in self.schedulers:
+            scheduler.step()
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate = %.7f' % lr)
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b3c52c7..c6b336c 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -61,6 +61,13 @@ class CycleGANModel(BaseModel):
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers = []
+            self.schedulers = []
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D_A)
+            self.optimizers.append(self.optimizer_D_B)
+            for optimizer in self.optimizers:
+                self.schedulers.append(networks.get_scheduler(optimizer, opt))

         print('---------- Networks initialized -------------')
         networks.print_network(self.netG_A)
@@ -204,16 +211,3 @@ class CycleGANModel(BaseModel):
         self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
         self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
         self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
-
-    def update_learning_rate(self):
-        lrd = self.opt.lr / self.opt.niter_decay
-        lr = self.old_lr - lrd
-        for param_group in self.optimizer_D_A.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_D_B.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_G.param_groups:
-            param_group['lr'] = lr
-
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
-        self.old_lr = lr
diff --git a/models/networks.py b/models/networks.py
index 6cf4169..2df58fe 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -3,21 +3,74 @@ import torch.nn as nn
 from torch.nn import init
 import functools
 from torch.autograd import Variable
+from torch.optim import lr_scheduler
 import numpy as np
 ###############################################################################
 # Functions
 ###############################################################################
-def weights_init(m):
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.uniform(m.weight.data, 0.0, 0.02)
+    elif classname.find('Linear') != -1:
+        init.uniform(m.weight.data, 0.0, 0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_xavier(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.xavier_normal(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.xavier_normal(m.weight.data, gain=1)
+    elif classname.find('BatchNorm2d') != -1:
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    # print(classname)
+    if classname.find('Conv') != -1:
+        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('BatchNorm2d') != -1:
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
     classname = m.__class__.__name__
+    print(classname)
     if classname.find('Conv') != -1:
-        m.weight.data.normal_(0.0, 0.02)
-        if hasattr(m.bias, 'data'):
-            m.bias.data.fill_(0)
+        init.orthogonal(m.weight.data, gain=1)
+    elif classname.find('Linear') != -1:
+        init.orthogonal(m.weight.data, gain=1)
     elif classname.find('BatchNorm2d') != -1:
-        m.weight.data.normal_(1.0, 0.02)
-        m.bias.data.fill_(0)
+        init.uniform(m.weight.data, 1.0, 0.02)
+        init.constant(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='normal'):
+    print('initialization method [%s]' % init_type)
+    if init_type == 'normal':
+        net.apply(weights_init_normal)
+    elif init_type == 'xavier':
+        net.apply(weights_init_xavier)
+    elif init_type == 'kaiming':
+        net.apply(weights_init_kaiming)
+    elif init_type == 'orthogonal':
+        net.apply(weights_init_orthogonal)
+    else:
+        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


 def get_norm_layer(norm_type='instance'):
@@ -25,12 +78,29 @@ def get_norm_layer(norm_type='instance'):
         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
     elif norm_type == 'instance':
         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
+    elif layer_type == 'none':
+        norm_layer = None
     else:
         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
     return norm_layer


-def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
+def get_scheduler(optimizer, opt):
+    if opt.lr_policy == 'lambda':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch - opt.niter) / float(opt.niter_decay+1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    else:
+        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+    return scheduler
+
+
+def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
     netG = None
     use_gpu = len(gpu_ids) > 0
     norm_layer = get_norm_layer(norm_type=norm)
@@ -50,12 +120,12 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo
         raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
     if len(gpu_ids) > 0:
         netG.cuda(device_id=gpu_ids[0])
-    netG.apply(weights_init)
+    init_weights(netG, init_type=init_type)
     return netG


 def define_D(input_nc, ndf, which_model_netD,
-             n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
+             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
     norm_layer = get_norm_layer(norm_type=norm)
@@ -71,7 +141,7 @@ def define_D(input_nc, ndf, which_model_netD,
                                   which_model_netD)
     if use_gpu:
         netD.cuda(device_id=gpu_ids[0])
-    netD.apply(weights_init)
+    init_weights(netD, init_type=init_type)
     return netD


@@ -238,17 +308,14 @@ class UnetGenerator(nn.Module):
         super(UnetGenerator, self).__init__()
         self.gpu_ids = gpu_ids

-        # currently support only input_nc == output_nc
-        assert(input_nc == output_nc)
-
         # construct unet structure
-        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
         for i in range(num_downs - 5):
-            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
-        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
-        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)

         self.model = unet_block

@@ -263,7 +330,7 @@ class UnetGenerator(nn.Module):
 # X -------------------identity---------------------- X
 # |-- downsampling -- |submodule| -- upsampling --|
 class UnetSkipConnectionBlock(nn.Module):
-    def __init__(self, outer_nc, inner_nc,
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
         super(UnetSkipConnectionBlock, self).__init__()
         self.outermost = outermost
@@ -271,8 +338,9 @@ class UnetSkipConnectionBlock(nn.Module):
             use_bias = norm_layer.func == nn.InstanceNorm2d
         else:
             use_bias = norm_layer == nn.InstanceNorm2d
-
-        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1, bias=use_bias)
         downrelu = nn.LeakyReLU(0.2, True)
         downnorm = norm_layer(inner_nc)
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index a524f2c..18ba53f 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -24,12 +24,12 @@ class Pix2PixModel(BaseModel):
         # load/define networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
-                                      opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids)
+                                      opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
         if self.isTrain:
             use_sigmoid = opt.no_lsgan
             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
                                           opt.which_model_netD,
-                                          opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids)
+                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
         if not self.isTrain or opt.continue_train:
             self.load_network(self.netG, 'G', opt.which_epoch)
         if self.isTrain:
@@ -43,10 +43,16 @@ class Pix2PixModel(BaseModel):
             self.criterionL1 = torch.nn.L1Loss()

             # initialize optimizers
+            self.schedulers = []
+            self.optimizers = []
             self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
             self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                 lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+            for optimizer in self.optimizers:
+                self.schedulers.append(networks.get_scheduler(optimizer, opt))

         print('---------- Networks initialized -------------')
         networks.print_network(self.netG)
@@ -134,13 +140,3 @@ class Pix2PixModel(BaseModel):
     def save(self, label):
         self.save_network(self.netG, 'G', label, self.gpu_ids)
         self.save_network(self.netD, 'D', label, self.gpu_ids)
-
-    def update_learning_rate(self):
-        lrd = self.opt.lr / self.opt.niter_decay
-        lr = self.old_lr - lrd
-        for param_group in self.optimizer_D.param_groups:
-            param_group['lr'] = lr
-        for param_group in self.optimizer_G.param_groups:
-            param_group['lr'] = lr
-        print('update learning rate: %f -> %f' % (self.old_lr, lr))
-        self.old_lr = lr
diff --git a/models/test_model.py b/models/test_model.py
index 03aef65..4af1fe1 100644
--- a/models/test_model.py
+++ b/models/test_model.py
@@ -17,6 +17,7 @@ class TestModel(BaseModel):
         self.netG = networks.define_G(opt.input_nc, opt.output_nc,
                                       opt.ngf, opt.which_model_netG,
                                       opt.norm, not opt.no_dropout,
+                                      opt.init_type,
                                       self.gpu_ids)
         which_epoch = opt.which_epoch
         self.load_network(self.netG, 'G', which_epoch)
diff --git a/options/base_options.py b/options/base_options.py
index de8bc74..c1b0733 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -39,6 +39,7 @@ class BaseOptions():
         self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
         self.parser.add_argument('--identity', type=float, default=0.0, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set optidentity = 0.1')
+        self.parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')

         self.initialized = True
diff --git a/options/train_options.py b/options/train_options.py
index a595017..f8a0ff6 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -21,4 +21,7 @@ class TrainOptions(BaseOptions):
         self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
         self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
         self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+        self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
+        self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
+
         self.isTrain = True
diff --git a/util/util.py b/util/util.py
index 781239f..4de0a74 100644
--- a/util/util.py
+++ b/util/util.py
@@ -11,6 +11,8 @@ import collections
 # |imtype|: the desired type of the converted numpy array
 def tensor2im(image_tensor, imtype=np.uint8):
     image_numpy = image_tensor[0].cpu().float().numpy()
+    if image_numpy.shape[0] == 1:
+        image_numpy = np.tile(image_numpy, (3, 1, 1))
     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
     return image_numpy.astype(imtype)
diff --git a/util/visualizer.py b/util/visualizer.py
index 3733525..02a36b7 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -4,7 +4,7 @@ import ntpath
 import time
 from . import util
 from . import html
-
+from pdb import set_trace as st
 class Visualizer():
     def __init__(self, opt):
         # self.opt = opt
@@ -66,7 +66,6 @@ class Visualizer():
             else:
                 idx = 1
             for label, image_numpy in visuals.items():
-                #image_numpy = np.flipud(image_numpy)
                 self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
                                win=self.display_id + idx)
                 idx += 1
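The 'new initialization' half of the commit replaces the single weights_init function with four per-module initializers chosen through the new --init_type option (default 'xavier' in base_options.py) and applied via net.apply. The sketch below reproduces that dispatch on a toy network, but it uses the current in-place torch.nn.init names (xavier_normal_ and friends) and normal draws for the 'normal' and BatchNorm branches, so treat it as an approximation of the idea rather than a copy of the commit's exact calls:

```python
import torch.nn as nn
from torch.nn import init

def init_weights(net, init_type='xavier'):
    # mirrors the dispatch added in models/networks.py, modernized to the
    # in-place torch.nn.init API; the commit itself calls init.uniform / init.xavier_normal / etc.
    def init_func(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1 or classname.find('Linear') != -1:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, 0.02)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=1)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=1)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
    net.apply(init_func)

net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
init_weights(net, init_type='xavier')  # or 'normal', 'kaiming', 'orthogonal'
```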
