summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--models/networks.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/models/networks.py b/models/networks.py
index 6777359..8df7809 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -77,9 +77,10 @@ def print_network(net):
##############################################################################
-# Defines the GAN loss used in LSGAN.
-# It is basically same as MSELoss, but it abstracts away the need to create
-# the target label tensor that has the same size as the input
+# Defines the GAN loss which uses either LSGAN or the regular GAN.
+# When LSGAN is used, it is basically same as MSELoss,
+# but it abstracts away the need to create the target label tensor
+# that has the same size as the input
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
@@ -307,7 +308,9 @@ class NLayerDiscriminator(nn.Module):
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=1)]
- sequence += [nn.Sigmoid()]
+
+ if use_sigmoid:
+ sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)