From 1e0fbeaf5bc81c0b9e23655dc31d28c77505b7cd Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Sat, 21 Oct 2017 12:34:44 +0800 Subject: Compatible with the latest version. Since the commit `Change device_id to device in python land #3133`, keyword `device_id` has been changed to `device` --- models/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'models') diff --git a/models/networks.py b/models/networks.py index 51e3f25..dca5489 100644 --- a/models/networks.py +++ b/models/networks.py @@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo else: raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: - netG.cuda(device_id=gpu_ids[0]) + netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. init_weights(netG, init_type=init_type) return netG @@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) + netD.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. init_weights(netD, init_type=init_type) return netD -- cgit v1.2.3-70-g09d2 From 9d0295bbb5d01ce40bd2792f72882128282f0338 Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Sat, 21 Oct 2017 17:06:30 +0800 Subject: Update networks.py fix typo --- models/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models') diff --git a/models/networks.py b/models/networks.py index dca5489..949659d 100644 --- a/models/networks.py +++ b/models/networks.py @@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. + netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version. init_weights(netD, init_type=init_type) return netD -- cgit v1.2.3-70-g09d2 From 343dda259b4b6a64a338fa61a9f1c70893b8fc7e Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Tue, 24 Oct 2017 12:21:00 +0800 Subject: Update base_model.py --- models/base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models') diff --git a/models/base_model.py b/models/base_model.py index 446a903..d62d189 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -44,7 +44,7 @@ class BaseModel(): save_path = os.path.join(self.save_dir, save_filename) torch.save(network.cpu().state_dict(), save_path) if len(gpu_ids) and torch.cuda.is_available(): - network.cuda(device_id=gpu_ids[0]) + network.cuda(device_id=gpu_ids[0]) # network.cuda(device=gpu_ids[0]) for the latest version. # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch_label): -- cgit v1.2.3-70-g09d2