From 1c5c2f50da5ae101077c27fdac2a12fb1619ec86 Mon Sep 17 00:00:00 2001 From: SsnL Date: Sat, 13 Jan 2018 23:04:43 -0500 Subject: fix resize_ issue #170 --- models/cycle_gan_model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'models/cycle_gan_model.py') diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py index fe06823..b7b840d 100644 --- a/models/cycle_gan_model.py +++ b/models/cycle_gan_model.py @@ -20,8 +20,6 @@ class CycleGANModel(BaseModel): nb = opt.batchSize size = opt.fineSize - self.input_A = self.Tensor(nb, opt.input_nc, size, size) - self.input_B = self.Tensor(nb, opt.output_nc, size, size) # load/define networks # The naming conversion is different from those used in the paper @@ -81,8 +79,11 @@ class CycleGANModel(BaseModel): AtoB = self.opt.which_direction == 'AtoB' input_A = input['A' if AtoB else 'B'] input_B = input['B' if AtoB else 'A'] - self.input_A.resize_(input_A.size()).copy_(input_A) - self.input_B.resize_(input_B.size()).copy_(input_B) + if len(self.gpu_ids) > 0: + input_A = input_A.cuda(self.gpu_ids[0], async=True) + input_B = input_B.cuda(self.gpu_ids[0], async=True) + self.input_A = input_A + self.input_B = input_B self.image_paths = input['A_paths' if AtoB else 'B_paths'] def forward(self): -- cgit v1.2.3-70-g09d2