author     junyanz <junyanz@berkeley.edu>   2017-04-18 03:56:29 -0700
committer  junyanz <junyanz@berkeley.edu>   2017-04-18 03:56:29 -0700
commit     efef2906b0a3f4fd265823ff4b4b99ccebeb6d05 (patch)
tree       0b041af61c33b2262b336598b9ff402595ed5f3b
parent     e17139daa9f3af07acee72dbf186d7eaf9ad089e (diff)
add unet and update README
-rw-r--r--  README.md           |  3
-rw-r--r--  models/networks.py  | 98
-rw-r--r--  util/visualizer.py  |  2

3 files changed, 84 insertions, 19 deletions
diff --git a/README.md b/README.md
index be327fe..6b5b44b 100644
--- a/README.md
+++ b/README.md
@@ -79,11 +79,10 @@ The test results will be saved to a html file here: `./results/facades_pix2pix/l
More example scripts can be found in the `scripts` directory.
-
## Training/test Details
- See `options/train_options.py` and `options/base_options.py` for training flags; see `options/test_options.py` and `options/base_options.py` for test flags.
- CPU/GPU: Set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode.
-- If you set `--display_id 0`, we will save the training results to `../checkpoints/name/web/index.html`. If you set `--display_id` > 0, we will use a browser-based graphics server. You need to call `th -ldisplay.start 8000 0.0.0.0` to start the server. See [[szym/display]](https://github.com/szym/display) for more details.
+- During training, you can visualize the current results. If you set `--display_id 0`, we will periodically save the training results to `[opt.checkpoints_dir]/[opt.name]/web/`. If you set `--display_id > 0`, the results will be shown on a local graphics web server launched by [szym/display: a lightweight display server for Torch](https://github.com/szym/display). To use it, you should have Torch, Python 3, and the display package installed. Start the server by invoking `th -ldisplay.start 8000 0.0.0.0`.
### CycleGAN Datasets
Download the CycleGAN datasets using the following script:
diff --git a/models/networks.py b/models/networks.py
index edbe972..9cb6222 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -26,18 +26,19 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm, gpu_ids=[]):
         norm_layer = InstanceNormalization
     else:
         print('normalization layer [%s] is not found' % norm)
-    if use_gpu:
-        assert(torch.cuda.is_available())
-
+
+    assert(torch.cuda.is_available() == use_gpu)
     if which_model_netG == 'resnet_9blocks':
         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=9, gpu_ids=gpu_ids)
     elif which_model_netG == 'resnet_6blocks':
-        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, n_blocks=6, gpu_ids=gpu_ids)
-    elif which_model_netG == 'unet':
-        netG = UnetGenerator(input_nc, output_nc, ngf, norm_layer, gpu_ids=gpu_ids)
+        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer, 6, gpu_ids=gpu_ids)
+    elif which_model_netG == 'unet_128':
+        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer, gpu_ids=gpu_ids)
+    elif which_model_netG == 'unet_256':
+        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer, gpu_ids=gpu_ids)
     else:
         print('Generator model name [%s] is not recognized' % which_model_netG)
-    if use_gpu:
+    if len(gpu_ids) > 0:
         netG.cuda()
     netG.apply(weights_init)
     return netG
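For reference, a minimal sketch of how the new generator names select an architecture. The call below assumes this commit's `models/networks.py` is importable and uses hypothetical channel counts; `unet_256` corresponds to `num_downs=8`, so a 256x256 input is halved eight times down to 1x1 at the bottleneck.

```python
# minimal usage sketch (hypothetical values; CPU mode via gpu_ids=[])
from models import networks

netG = networks.define_G(input_nc=3, output_nc=3, ngf=64,
                         which_model_netG='unet_256', norm='batch', gpu_ids=[])

# 'unet_256' -> num_downs=8: eight stride-2 downsamplings take 256 -> 1
assert 256 // 2 ** 8 == 1
```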
@@ -47,9 +48,7 @@ def define_D(input_nc, ndf, which_model_netD,
              n_layers_D=3, use_sigmoid=False, gpu_ids=[]):
     netD = None
     use_gpu = len(gpu_ids) > 0
-    if use_gpu:
-        assert(torch.cuda.is_available())
-
+    assert(torch.cuda.is_available() == use_gpu)
     if which_model_netD == 'basic':
         netD = define_D(input_nc, ndf, 'n_layers', use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
     elif which_model_netD == 'n_layers':
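As the branch above shows, `'basic'` re-enters `define_D` through the `'n_layers'` branch with the default `n_layers_D=3`. A minimal sketch (hypothetical values; `input_nc=6` assumes a conditional discriminator fed a concatenated input/output image pair):

```python
# minimal sketch: 'basic' delegates to the 'n_layers' branch with n_layers_D=3
from models import networks

netD = networks.define_D(input_nc=6, ndf=64, which_model_netD='basic',
                         use_sigmoid=True, gpu_ids=[])
```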
@@ -143,7 +142,7 @@ class ResnetGenerator(nn.Module):
         mult = 2**n_downsampling
         for i in range(n_blocks):
-            model += [Resnet_block(ngf * mult, 'zero', norm_layer=norm_layer)]
+            model += [ResnetBlock(ngf * mult, 'zero', norm_layer=norm_layer)]
         for i in range(n_downsampling):
             mult = 2**(n_downsampling - i)
@@ -165,11 +164,10 @@ class ResnetGenerator(nn.Module):
         return self.model(input)

-
 # Define a resnet block
-class Resnet_block(nn.Module):
+class ResnetBlock(nn.Module):
     def __init__(self, dim, padding_type, norm_layer):
-        super(Resnet_block, self).__init__()
+        super(ResnetBlock, self).__init__()
         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer)

     def build_conv_block(self, dim, padding_type, norm_layer):
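The rename is cosmetic; the block itself is a standard residual unit. A minimal sketch of the shape contract, assuming (as elsewhere in this file) that `forward` returns `x + self.conv_block(x)`:

```python
# the residual skip requires conv_block to preserve the tensor shape
import torch

x = torch.randn(1, 256, 64, 64)         # ngf * mult channels at the bottleneck
conv_out = torch.randn(1, 256, 64, 64)  # stand-in for self.conv_block(x)
out = x + conv_out                      # identity shortcut, shape unchanged
assert out.size() == x.size()
```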
@@ -193,15 +191,35 @@ class Resnet_block(nn.Module):
         return out

-# Defines the Unet geneator.
+# Defines the Unet generator.
+# |num_downs|: number of downsamplings in the UNet. For example,
+# if |num_downs| == 7, an image of size 128x128 will become of size 1x1
+# at the bottleneck
 class UnetGenerator(nn.Module):
-    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, gpu_ids=[]):
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                 norm_layer=nn.BatchNorm2d, gpu_ids=[]):
         super(UnetGenerator, self).__init__()
         self.input_nc = input_nc
         self.output_nc = output_nc
         self.ngf = ngf
         self.gpu_ids = gpu_ids
+        # currently support only input_nc == output_nc
+        assert(input_nc == output_nc)
+
+        # build the U-Net recursively, from the innermost block outward
+        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
+        for i in range(num_downs - 5):
+            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
+        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
+        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
+        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
+        unet_block = UnetSkipConnectionBlock(input_nc, ngf, unet_block,
+                                             outermost=True)
+
+        self.model = unet_block
+
     def forward(self, input):
         if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
             return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
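A quick sanity check on the construction above, as a sketch: one innermost block, `num_downs - 5` middle blocks, and four outer blocks each downsample exactly once, so the total number of stride-2 halvings equals `num_downs`:

```python
# counting the stride-2 downsamplings for the 'unet_128' case (num_downs=7)
num_downs = 7
n_blocks = 1 + (num_downs - 5) + 4  # innermost + middle loop + four outer blocks
assert n_blocks == num_downs
assert 128 // 2 ** num_downs == 1  # a 128x128 input reaches 1x1 at the bottleneck
```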
@@ -209,6 +227,54 @@ class UnetGenerator(nn.Module):
         return self.model(input)

+# Defines the submodule with skip connection.
+# X -------------------identity---------------------- X
+#   |-- downsampling -- |submodule| -- upsampling --|
+class UnetSkipConnectionBlock(nn.Module):
+    def __init__(self, outer_nc, inner_nc,
+                 submodule=None, outermost=False, innermost=False):
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outer_nc = outer_nc
+        self.inner_nc = inner_nc
+        self.outermost = outermost
+        self.innermost = innermost
+        downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = nn.BatchNorm2d(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = nn.BatchNorm2d(outer_nc)
+
+        if outermost:
+            # the submodule returns inner_nc * 2 channels (its output
+            # concatenated with its input); map them back to outer_nc
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, output_padding=0)
+            down = [downconv]
+            up = [uprelu, upconv, nn.Tanh()]
+            model = down + [submodule] + up
+        elif innermost:
+            # no submodule below, so upconv sees only inner_nc channels
+            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, output_padding=0)
+            down = [downrelu, downconv]
+            up = [uprelu, upconv, upnorm]
+            model = down + up
+        else:
+            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                        kernel_size=4, stride=2,
+                                        padding=1, output_padding=0)
+            down = [downrelu, downconv, downnorm]
+            up = [uprelu, upconv, upnorm]
+            model = down + [submodule] + up
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        if self.outermost:
+            # the outermost block emits the final image; no skip connection here
+            return self.model(x)
+        # concatenate the skip connection along the channel dimension
+        return torch.cat([self.model(x), x], 1)
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
 class NLayerDiscriminator(nn.Module):
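To make the channel bookkeeping in `UnetSkipConnectionBlock` concrete, here is a small shape check (a sketch with hypothetical sizes): every non-outermost block returns its input concatenated with its upsampled output, which is why `inner_nc * 2` appears as the input width of the transposed convolutions.

```python
# shape sketch for one middle block with outer_nc=64, inner_nc=128
import torch

x = torch.randn(1, 64, 32, 32)       # block input, outer_nc channels
up_out = torch.randn(1, 64, 32, 32)  # upsampled output, back to outer_nc
skip = torch.cat([up_out, x], 1)     # returned tensor: outer_nc * 2 channels
assert skip.size(1) == 128           # the parent block's upconv expects this
```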
diff --git a/util/visualizer.py b/util/visualizer.py
index 0b8578e..4daf506 100644
--- a/util/visualizer.py
+++ b/util/visualizer.py
@@ -27,7 +27,7 @@ class Visualizer():
     def display_current_results(self, visuals, epoch):
         if self.display_id > 0:  # show images in the browser
             idx = 0
-            for label, image_numpy in visuals:
+            for label, image_numpy in visuals.items():
                 image_numpy = np.flipud(image_numpy)
                 self.display.image(image_numpy, title=label,
                                    win=self.display_id + idx)
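The fix matters because `visuals` is a dict mapping labels to numpy images: iterating a dict directly yields only its keys, so the old tuple unpacking would fail at runtime. A minimal sketch of the intended flow, assuming the `display` Python client from szym/display is installed and a server was started with `th -ldisplay.start 8000 0.0.0.0`:

```python
# hypothetical stand-in for the visuals dict built by the models
from collections import OrderedDict
import numpy as np
import display  # Python client shipped with szym/display (an assumption here)

visuals = OrderedDict([('real_A', np.random.rand(256, 256, 3)),
                       ('fake_B', np.random.rand(256, 256, 3))])
for idx, (label, image_numpy) in enumerate(visuals.items()):
    # mirror visualizer.py: one window per image, keyed by display_id + idx
    display.image(np.flipud(image_numpy), title=label, win=1 + idx)
```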