summaryrefslogtreecommitdiff
path: root/models
diff options
context:
space:
mode:
Diffstat (limited to 'models')
-rw-r--r--models/base_model.py1
-rw-r--r--models/cycle_gan_model.py3
-rw-r--r--models/models.py9
-rw-r--r--models/networks.py9
-rw-r--r--models/one_direction_test_model.py51
-rw-r--r--models/pix2pix_model.py5
6 files changed, 64 insertions, 14 deletions
diff --git a/models/base_model.py b/models/base_model.py
index ce18635..9b92bb4 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -1,6 +1,5 @@
import os
import torch
-from pdb import set_trace as st
class BaseModel():
def name(self):
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
index d361e47..451002d 100644
--- a/models/cycle_gan_model.py
+++ b/models/cycle_gan_model.py
@@ -2,7 +2,6 @@ import numpy as np
import torch
import os
from collections import OrderedDict
-from pdb import set_trace as st
from torch.autograd import Variable
import itertools
import util.util as util
@@ -72,7 +71,7 @@ class CycleGANModel(BaseModel):
print('-----------------------------------------------')
def set_input(self, input):
- AtoB = self.opt.which_direction is 'AtoB'
+ AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
input_B = input['B' if AtoB else 'A']
self.input_A.resize_(input_A.size()).copy_(input_A)
diff --git a/models/models.py b/models/models.py
index 7e790d0..8fea4f4 100644
--- a/models/models.py
+++ b/models/models.py
@@ -4,12 +4,17 @@ def create_model(opt):
print(opt.model)
if opt.model == 'cycle_gan':
from .cycle_gan_model import CycleGANModel
- assert(opt.align_data == False)
+ #assert(opt.align_data == False)
model = CycleGANModel()
- if opt.model == 'pix2pix':
+ elif opt.model == 'pix2pix':
from .pix2pix_model import Pix2PixModel
assert(opt.align_data == True)
model = Pix2PixModel()
+ elif opt.model == 'one_direction_test':
+ from .one_direction_test_model import OneDirectionTestModel
+ model = OneDirectionTestModel()
+ else:
+ raise ValueError("Model [%s] not recognized." % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
diff --git a/models/networks.py b/models/networks.py
index 60e1777..b0f3b11 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
-from pdb import set_trace as st
import numpy as np
###############################################################################
@@ -13,7 +12,7 @@ def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
- elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
+ elif classname.find('BatchNorm2d') != -1 or classname.find('InstanceNormalization') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
@@ -162,7 +161,7 @@ class ResnetGenerator(nn.Module):
self.model = nn.Sequential(*model)
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
@@ -222,7 +221,7 @@ class UnetGenerator(nn.Module):
self.model = unet_block
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
@@ -323,7 +322,7 @@ class NLayerDiscriminator(nn.Module):
self.model = nn.Sequential(*sequence)
def forward(self, input):
- if isinstance(input.data, torch.cuda.FloatTensor) and self.gpu_ids:
+ if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
diff --git a/models/one_direction_test_model.py b/models/one_direction_test_model.py
new file mode 100644
index 0000000..d4f6442
--- /dev/null
+++ b/models/one_direction_test_model.py
@@ -0,0 +1,51 @@
+from torch.autograd import Variable
+from collections import OrderedDict
+import util.util as util
+from .base_model import BaseModel
+from . import networks
+
+
+class OneDirectionTestModel(BaseModel):
+ def name(self):
+ return 'OneDirectionTestModel'
+
+ def initialize(self, opt):
+ BaseModel.initialize(self, opt)
+
+ nb = opt.batchSize
+ size = opt.fineSize
+ self.input_A = self.Tensor(nb, opt.input_nc, size, size)
+
+ assert(not self.isTrain)
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
+ opt.ngf, opt.which_model_netG,
+ opt.norm, opt.use_dropout,
+ self.gpu_ids)
+ which_epoch = opt.which_epoch
+ #AtoB = self.opt.which_direction == 'AtoB'
+ #which_network = 'G_A' if AtoB else 'G_B'
+ self.load_network(self.netG_A, 'G', which_epoch)
+
+ print('---------- Networks initialized -------------')
+ networks.print_network(self.netG_A)
+ print('-----------------------------------------------')
+
+ def set_input(self, input):
+ AtoB = self.opt.which_direction == 'AtoB'
+ input_A = input['A' if AtoB else 'B']
+ self.input_A.resize_(input_A.size()).copy_(input_A)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def test(self):
+ self.real_A = Variable(self.input_A)
+ self.fake_B = self.netG_A.forward(self.real_A)
+
+ #get image paths
+ def get_image_paths(self):
+ return self.image_paths
+
+ def get_current_visuals(self):
+ real_A = util.tensor2im(self.real_A.data)
+ fake_B = util.tensor2im(self.fake_B.data)
+ return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
+
diff --git a/models/pix2pix_model.py b/models/pix2pix_model.py
index 0e02ebf..34e0bac 100644
--- a/models/pix2pix_model.py
+++ b/models/pix2pix_model.py
@@ -2,13 +2,11 @@ import numpy as np
import torch
import os
from collections import OrderedDict
-from pdb import set_trace as st
from torch.autograd import Variable
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
-from pdb import set_trace as st
class Pix2PixModel(BaseModel):
def name(self):
@@ -55,7 +53,7 @@ class Pix2PixModel(BaseModel):
print('-----------------------------------------------')
def set_input(self, input):
- AtoB = self.opt.which_direction is 'AtoB'
+ AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
input_B = input['B' if AtoB else 'A']
self.input_A.resize_(input_A.size()).copy_(input_A)
@@ -108,7 +106,6 @@ class Pix2PixModel(BaseModel):
self.loss_G.backward()
def optimize_parameters(self):
- # st()
self.forward()
self.optimizer_D.zero_grad()