summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/cycle_gan_model.py16
-rw-r--r--models/models.py1
-rw-r--r--models/networks.py5
-rw-r--r--models/pix2pix_model.py3
4 files changed, 6 insertions, 19 deletions
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index b7b840d..85432bb 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -1,6 +1,4 @@
-import numpy as np
import torch
-import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
@@ -8,7 +6,6 @@ import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
-import sys
class CycleGANModel(BaseModel):
@@ -17,10 +14,6 @@ class CycleGANModel(BaseModel):
def initialize(self, opt):
BaseModel.initialize(self, opt)
-
- nb = opt.batchSize
- size = opt.fineSize
-
# load/define networks
# The naming conversion is different from those used in the paper
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
@@ -47,7 +40,6 @@ class CycleGANModel(BaseModel):
self.load_network(self.netD_B, 'D_B', which_epoch)
if self.isTrain:
- self.old_lr = opt.lr
self.fake_A_pool = ImagePool(opt.pool_size)
self.fake_B_pool = ImagePool(opt.pool_size)
# define loss functions
@@ -129,7 +121,7 @@ class CycleGANModel(BaseModel):
self.loss_D_B = loss_D_B.data[0]
def backward_G(self):
- lambda_idt = self.opt.identity
+ lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
# Identity loss
@@ -200,8 +192,8 @@ class CycleGANModel(BaseModel):
def get_current_errors(self):
ret_errors = OrderedDict([('D_A', self.loss_D_A), ('G_A', self.loss_G_A), ('Cyc_A', self.loss_cycle_A),
- ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
- if self.opt.identity > 0.0:
+ ('D_B', self.loss_D_B), ('G_B', self.loss_G_B), ('Cyc_B', self.loss_cycle_B)])
+ if self.opt.lambda_identity > 0.0:
ret_errors['idt_A'] = self.loss_idt_A
ret_errors['idt_B'] = self.loss_idt_B
return ret_errors
@@ -215,7 +207,7 @@ class CycleGANModel(BaseModel):
rec_B = util.tensor2im(self.rec_B)
ret_visuals = OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)])
- if self.opt.isTrain and self.opt.identity > 0.0:
+ if self.opt.isTrain and self.opt.lambda_identity > 0.0:
ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
return ret_visuals
diff --git a/models/models.py b/models/models.py
index d5bb9d8..39cc020 100644
--- a/models/models.py
+++ b/models/models.py
@@ -1,4 +1,3 @@
-
def create_model(opt):
model = None
print(opt.model)
diff --git a/models/networks.py b/models/networks.py
index da2f59c..b118c6a 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -4,7 +4,6 @@ from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
-import numpy as np
###############################################################################
# Functions
###############################################################################
@@ -434,6 +433,7 @@ class NLayerDiscriminator(nn.Module):
else:
return self.model(input)
+
class PixelDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(PixelDiscriminator, self).__init__()
@@ -442,7 +442,7 @@ class PixelDiscriminator(nn.Module):
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
-
+
self.net = [
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
nn.LeakyReLU(0.2, True),
@@ -461,4 +461,3 @@ class PixelDiscriminator(nn.Module):
return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
else:
return self.net(input)
-
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 74a941e..78f8d69 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -1,6 +1,4 @@
-import numpy as np
import torch
-import os
from collections import OrderedDict
from torch.autograd import Variable
import util.util as util
@@ -32,7 +30,6 @@ class Pix2PixModel(BaseModel):
if self.isTrain:
self.fake_AB_pool = ImagePool(opt.pool_size)
- self.old_lr = opt.lr
# define loss functions
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionL1 = torch.nn.L1Loss()