diff options
| -rw-r--r-- | models/networks.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/models/networks.py b/models/networks.py index 6777359..8df7809 100644 --- a/models/networks.py +++ b/models/networks.py @@ -77,9 +77,10 @@ def print_network(net): ############################################################################## -# Defines the GAN loss used in LSGAN. -# It is basically same as MSELoss, but it abstracts away the need to create -# the target label tensor that has the same size as the input +# Defines the GAN loss which uses either LSGAN or the regular GAN. +# When LSGAN is used, it is basically same as MSELoss, +# but it abstracts away the need to create the target label tensor +# that has the same size as the input class GANLoss(nn.Module): def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor): @@ -307,7 +308,9 @@ class NLayerDiscriminator(nn.Module): ] sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)] - sequence += [nn.Sigmoid()] + + if use_sigmoid: + sequence += [nn.Sigmoid()] self.model = nn.Sequential(*sequence) |
