From 1e0fbeaf5bc81c0b9e23655dc31d28c77505b7cd Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Sat, 21 Oct 2017 12:34:44 +0800 Subject: Compatible with the latest version. Since the commit `Change device_id to device in python land #3133`, keyword `device_id` has been changed to `device` --- models/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'models/networks.py') diff --git a/models/networks.py b/models/networks.py index 51e3f25..dca5489 100644 --- a/models/networks.py +++ b/models/networks.py @@ -118,7 +118,7 @@ def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropo else: raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) if len(gpu_ids) > 0: - netG.cuda(device_id=gpu_ids[0]) + netG.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. init_weights(netG, init_type=init_type) return netG @@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) + netD.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. init_weights(netD, init_type=init_type) return netD -- cgit v1.2.3-70-g09d2 From 9d0295bbb5d01ce40bd2792f72882128282f0338 Mon Sep 17 00:00:00 2001 From: LambdaWill <574819595@qq.com> Date: Sat, 21 Oct 2017 17:06:30 +0800 Subject: Update networks.py fix typo --- models/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'models/networks.py') diff --git a/models/networks.py b/models/networks.py index dca5489..949659d 100644 --- a/models/networks.py +++ b/models/networks.py @@ -139,7 +139,7 @@ def define_D(input_nc, ndf, which_model_netD, raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD) if use_gpu: - netD.cuda(device_id=gpu_ids[0]) # or netG.cuda(device=gpu_ids[0]) for latest version. + netD.cuda(device_id=gpu_ids[0]) # or netD.cuda(device=gpu_ids[0]) for latest version. init_weights(netD, init_type=init_type) return netD -- cgit v1.2.3-70-g09d2