summaryrefslogtreecommitdiff
path: root/models/networks.py
diff options
context:
space:
mode:
Diffstat (limited to 'models/networks.py')
-rw-r--r--models/networks.py26
1 files changed, 10 insertions, 16 deletions
diff --git a/models/networks.py b/models/networks.py
index a7b6860..5624a7b 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -201,19 +201,15 @@ class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
norm_layer=nn.BatchNorm2d, gpu_ids=[]):
super(UnetGenerator, self).__init__()
- self.input_nc = input_nc
- self.output_nc = output_nc
- self.ngf = ngf
self.gpu_ids = gpu_ids
# currently support only input_nc == output_nc
assert(input_nc == output_nc)
+ # construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, innermost=True)
-
for i in range(num_downs - 5):
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block)
-
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block)
@@ -236,9 +232,8 @@ class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc,
submodule=None, outermost=False, innermost=False):
super(UnetSkipConnectionBlock, self).__init__()
- self.outer_nc = outer_nc
- self.inner_nc = inner_nc
- self.innermost = innermost
+ self.outermost = outermost
+
downconv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4,
stride=2, padding=1)
downrelu = nn.LeakyReLU(0.2, True)
@@ -249,21 +244,21 @@ class UnetSkipConnectionBlock(nn.Module):
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
- padding=1, output_padding=0)
+ padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
- padding=1, output_padding=0)
+ padding=1)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
- padding=1, output_padding=0)
+ padding=1)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
model = down + [submodule] + up
@@ -271,11 +266,10 @@ class UnetSkipConnectionBlock(nn.Module):
self.model = nn.Sequential(*model)
def forward(self, x):
- #print(self.outer_nc, self.inner_nc, self.innermost)
- #print(x.size())
- #print(self.model(x).size())
- return torch.cat([self.model(x), x], 1)
-
+ if self.outermost:
+ return self.model(x)
+ else:
+ return torch.cat([self.model(x), x], 1)
# Defines the PatchGAN discriminator with the specified arguments.