1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
|
import torch
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class CycleGANModel(BaseModel):
    """CycleGAN model: two generators and two discriminators with cycle losses.

    Naming used in this code (name used in the paper):
    G_A (G), G_B (F), D_A (D_Y), D_B (D_X).

    ``netG_A`` turns ``real_A`` into ``fake_B`` and ``netG_B`` turns ``real_B``
    into ``fake_A``.  ``netD_A`` scores B-side images (``real_B`` / ``fake_B``)
    and ``netD_B`` scores A-side images (``real_A`` / ``fake_A``).

    Written against the PyTorch >= 0.4 tensor API (``non_blocking``,
    ``torch.no_grad()`` and ``Tensor.item()``).
    """

    def name(self):
        """Return the display name of this model."""
        return 'CycleGANModel'

    def initialize(self, opt):
        """Create the networks and, when training, losses/optimizers/schedulers.

        Args:
            opt: options object.  Attributes read here: ``input_nc``,
                ``output_nc``, ``ngf``, ``ndf``, ``which_model_netG``,
                ``which_model_netD``, ``n_layers_D``, ``norm``, ``no_dropout``,
                ``init_type``, ``no_lsgan``, ``continue_train``,
                ``which_epoch``, ``pool_size``, ``lr`` and ``beta1``.
        """
        BaseModel.initialize(self, opt)

        # load/define networks
        # The naming convention is different from the one used in the paper.
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
            opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        # G_B goes the other way, so its channel counts are swapped.
        self.netG_B = networks.define_G(
            opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG,
            opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            # A sigmoid output is requested only when LSGAN is disabled.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(
                opt.output_nc, opt.ndf, opt.which_model_netD,
                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type,
                self.gpu_ids)
            self.netD_B = networks.define_D(
                opt.input_nc, opt.ndf, opt.which_model_netD,
                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type,
                self.gpu_ids)

        # Saved weights are loaded for testing and for resumed training.
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A, 'G_A', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', which_epoch)
                self.load_network(self.netD_B, 'D_B', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan,
                                                 tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers: one shared by both generators, one per
            # discriminator
            adam_betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(),
                                self.netG_B.parameters()),
                lr=opt.lr, betas=adam_betas)
            self.optimizer_D_A = torch.optim.Adam(
                self.netD_A.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_D_B = torch.optim.Adam(
                self.netD_B.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizers = [self.optimizer_G,
                               self.optimizer_D_A,
                               self.optimizer_D_B]
            self.schedulers = [networks.get_scheduler(optimizer, opt)
                               for optimizer in self.optimizers]

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A)
        networks.print_network(self.netG_B)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Store one batch on the model, moving it to the first GPU if any.

        Args:
            input: mapping with keys ``'A'``, ``'B'`` and ``'A_paths'`` or
                ``'B_paths'``.  When ``opt.which_direction`` is not ``'AtoB'``
                the two image entries are swapped and ``'B_paths'`` is used.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # `async=` was renamed to `non_blocking=` in PyTorch 0.4; `async`
            # is a reserved word from Python 3.7 on and no longer parses.
            input_A = input_A.cuda(self.gpu_ids[0], non_blocking=True)
            input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Expose the current batch as ``real_A`` / ``real_B`` for the losses."""
        # Tensors and Variables are one type since PyTorch 0.4, so no
        # Variable wrapper is needed.
        self.real_A = self.input_A
        self.real_B = self.input_B

    def test(self):
        """Run both translation cycles on the current batch without gradients.

        Sets ``fake_B``, ``rec_A``, ``fake_A`` and ``rec_B``.
        """
        # `volatile=True` is ignored since PyTorch 0.4; torch.no_grad() is
        # what actually prevents autograd graphs from being built here.
        with torch.no_grad():
            fake_B = self.netG_A(self.input_A)
            self.rec_A = self.netG_B(fake_B).data
            self.fake_B = fake_B.data

            fake_A = self.netG_B(self.input_B)
            self.rec_B = self.netG_A(fake_A).data
            self.fake_A = fake_A.data

    # get image paths
    def get_image_paths(self):
        """Return the image paths stored by the last ``set_input`` call."""
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the GAN loss of one discriminator.

        Args:
            netD: discriminator network to evaluate.
            real: batch of real images.
            fake: batch of generated images; detached before use so no
                gradient reaches the generator that produced it.

        Returns:
            The combined loss tensor, ``0.5 * (loss_real + loss_fake)``.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Backpropagate D_A's loss on ``real_B`` vs. pooled fake B images."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
        # .item() replaces .data[0], which fails on 0-dim loss tensors.
        self.loss_D_A = loss_D_A.item()

    def backward_D_B(self):
        """Backpropagate D_B's loss on ``real_A`` vs. pooled fake A images."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
        self.loss_D_B = loss_D_B.item()

    def backward_G(self):
        """Compute all generator losses and backpropagate their sum.

        The total is GAN loss (both directions) + cycle loss (weighted by
        ``opt.lambda_A`` / ``opt.lambda_B``) + optional identity loss (enabled
        when ``opt.lambda_identity > 0``).  Stores the generated images and
        the scalar value of every loss term on ``self``.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            idt_A = self.netG_A(self.real_B)
            loss_idt_A = (self.criterionIdt(idt_A, self.real_B)
                          * lambda_B * lambda_idt)
            # G_B should be identity if real_A is fed.
            idt_B = self.netG_B(self.real_A)
            loss_idt_B = (self.criterionIdt(idt_B, self.real_A)
                          * lambda_A * lambda_idt)

            self.idt_A = idt_A.data
            self.idt_B = idt_B.data
            self.loss_idt_A = loss_idt_A.item()
            self.loss_idt_B = loss_idt_B.item()
        else:
            # Plain zeros keep the sum below valid without extra forward passes.
            loss_idt_A = 0
            loss_idt_B = 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        fake_B = self.netG_A(self.real_A)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.real_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # Forward cycle loss
        rec_A = self.netG_B(fake_B)
        loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * lambda_A

        # Backward cycle loss
        rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * lambda_B

        # combined loss
        loss_G = (loss_G_A + loss_G_B + loss_cycle_A + loss_cycle_B
                  + loss_idt_A + loss_idt_B)
        loss_G.backward()

        # Keep graph-free copies for the image pools and for visualisation.
        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = loss_G_A.item()
        self.loss_G_B = loss_G_B.item()
        self.loss_cycle_A = loss_cycle_A.item()
        self.loss_cycle_B = loss_cycle_B.item()

    def optimize_parameters(self):
        """Run one training step: generators first, then D_A, then D_B."""
        # forward
        self.forward()
        # G_A and G_B
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A -- zero_grad() runs after the generator step, so whatever
        # backward_G accumulated in the discriminators is discarded first.
        self.optimizer_D_A.zero_grad()
        self.backward_D_A()
        self.optimizer_D_A.step()
        # D_B
        self.optimizer_D_B.zero_grad()
        self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_errors(self):
        """Return an ``OrderedDict`` of the most recent scalar loss values.

        The identity entries are included only when
        ``opt.lambda_identity > 0``.
        """
        ret_errors = OrderedDict([
            ('D_A', self.loss_D_A),
            ('G_A', self.loss_G_A),
            ('Cyc_A', self.loss_cycle_A),
            ('D_B', self.loss_D_B),
            ('G_B', self.loss_G_B),
            ('Cyc_B', self.loss_cycle_B),
        ])
        if self.opt.lambda_identity > 0.0:
            ret_errors['idt_A'] = self.loss_idt_A
            ret_errors['idt_B'] = self.loss_idt_B
        return ret_errors

    def get_current_visuals(self):
        """Return an ``OrderedDict`` of the current images via ``util.tensor2im``.

        The identity images are included only when training with
        ``opt.lambda_identity > 0``.
        """
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        ret_visuals = OrderedDict([
            ('real_A', real_A),
            ('fake_B', fake_B),
            ('rec_A', rec_A),
            ('real_B', real_B),
            ('fake_A', fake_A),
            ('rec_B', rec_B),
        ])
        if self.opt.isTrain and self.opt.lambda_identity > 0.0:
            ret_visuals['idt_A'] = util.tensor2im(self.idt_A)
            ret_visuals['idt_B'] = util.tensor2im(self.idt_B)
        return ret_visuals

    def save(self, label):
        """Save all four networks under ``label``.

        The discriminators are only created when training, so this is meant
        to be called on a model initialised in training mode.
        """
        self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
        self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
        self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
        self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
|