| author | Taesung Park <taesung_park@berkeley.edu> | 2017-04-18 04:12:53 -0700 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-04-18 04:12:53 -0700 |
| commit | b3c7f2fb04dbd3dba3e89edeaf432925993bd303 (patch) | |
| tree | 51a629ebde350ab2175dd4733eb4760c6bf0a0c5 /models/networks.py | |
| parent | c039c2596d61f70382e78a6d16203ce820572585 (diff) | |
Fixed bugs in the Unet structure
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 26 |
1 file changed, 10 insertions, 16 deletions
```diff
diff --git a/models/networks.py b/models/networks.py
index a7b6860..5624a7b 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -201,19 +201,15 @@ class UnetGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                  norm_layer=nn.BatchNorm2d, gpu_ids=[]):
         super(UnetGenerator, self).__init__()
-        self.input_nc = input_nc
-        self.output_nc = output_nc
-        self.ngf = ngf
         self.gpu_ids = gpu_ids
 
         # currently support only input_nc == output_nc
         assert(input_nc == output_nc)
+
         # construct unet structure
         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
-
         for i in range(num_downs - 5):
             unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
-
         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -236,9 +232,8 @@ class UnetSkipConnectionBlock(nn.Module):
     def __init__(self, outer_nc, inner_nc,
                  submodule=None, outermost=False, innermost=False):
         super(UnetSkipConnectionBlock, self).__init__()
-        self.outer_nc = outer_nc
-        self.inner_nc = inner_nc
-        self.innermost = innermost
+        self.outermost = outermost
+
         downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
                              stride=2, padding=1)
         downrelu = nn.LeakyReLU(0.2, True)
@@ -249,21 +244,21 @@ class UnetSkipConnectionBlock(nn.Module):
         if outermost:
             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
-                                        padding=1, output_padding=0)
+                                        padding=1)
             down = [downconv]
             up = [uprelu, upconv, nn.Tanh()]
             model = down + [submodule] + up
         elif innermost:
             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
-                                        padding=1, output_padding=0)
+                                        padding=1)
             down = [downrelu, downconv]
             up = [uprelu, upconv, upnorm]
             model = down + up
         else:
             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
-                                        padding=1, output_padding=0)
+                                        padding=1)
             down = [downrelu, downconv, downnorm]
             up = [uprelu, upconv, upnorm]
             model = down + [submodule] + up
@@ -271,11 +266,10 @@ class UnetSkipConnectionBlock(nn.Module):
         self.model = nn.Sequential(*model)
 
     def forward(self, x):
-        #print(self.outer_nc, self.inner_nc, self.innermost)
-        #print(x.size())
-        #print(self.model(x).size())
-        return torch.cat([self.model(x), x], 1)
-
+        if self.outermost:
+            return self.model(x)
+        else:
+            return torch.cat([self.model(x), x], 1)
 
 
 # Defines the PatchGAN discriminator with the specified arguments.
```
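For context on the `forward` change: the outermost U-Net block already maps the input all the way back to `output_nc` channels (through the `Tanh` at the top), so concatenating the input onto its output, as the old `return torch.cat([self.model(x), x], 1)` did for every block, would leave the generator's final output with `input_nc + output_nc` channels instead of an image. The minimal sketch below (not from the repository; the stand-in layers are hypothetical) illustrates the channel mismatch:

```python
# Hypothetical stand-in for the outermost block's model: conv down,
# transposed conv up, Tanh -- mirroring the `outermost` branch in the diff.
import torch
import torch.nn as nn

input_nc = output_nc = 3
x = torch.randn(1, input_nc, 256, 256)

outermost_model = nn.Sequential(
    nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
    nn.ReLU(True),
    nn.ConvTranspose2d(64, output_nc, kernel_size=4, stride=2, padding=1),
    nn.Tanh(),
)

out = outermost_model(x)
print(out.shape)                     # torch.Size([1, 3, 256, 256]) -- fixed behavior: return self.model(x)
print(torch.cat([out, x], 1).shape)  # torch.Size([1, 6, 256, 256]) -- pre-fix behavior: wrong channel count
```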
