diff options
| author | junyanz <junyanz@berkeley.edu> | 2017-04-22 04:23:14 -0700 |
|---|---|---|
| committer | junyanz <junyanz@berkeley.edu> | 2017-04-22 04:23:14 -0700 |
| commit | a7917caaeaefe51db959b8f3ae50a20e726fbd93 (patch) | |
| tree | 0aa8ac95130314d26e67ca0be9414df490965f84 /models/networks.py | |
| parent | 318936a415814e4b7b7affd3efd69c2aa1331074 (diff) | |
fix the padding/kernel_size in D
Diffstat (limited to 'models/networks.py')
| -rw-r--r-- | models/networks.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/models/networks.py b/models/networks.py index 8df7809..2e3ad79 100644 --- a/models/networks.py +++ b/models/networks.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from torch.autograd import Variable from pdb import set_trace as st +import numpy as np ############################################################################### # Functions @@ -279,8 +280,9 @@ class NLayerDiscriminator(nn.Module): self.gpu_ids = gpu_ids kw = 4 + padw = int(np.ceil((kw-1)/2)) sequence = [ - nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=2), + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True) ] @@ -291,7 +293,7 @@ class NLayerDiscriminator(nn.Module): nf_mult = min(2**n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, - kernel_size=kw, stride=2, padding=2), + kernel_size=kw, stride=2, padding=padw), # TODO: use InstanceNorm nn.BatchNorm2d(ndf * nf_mult), nn.LeakyReLU(0.2, True) @@ -301,13 +303,13 @@ class NLayerDiscriminator(nn.Module): nf_mult = min(2**n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, - kernel_size=1, stride=2, padding=2), + kernel_size=kw, stride=1, padding=padw), # TODO: useInstanceNorm nn.BatchNorm2d(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] - sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)] + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] if use_sigmoid: sequence += [nn.Sigmoid()] |
