summaryrefslogtreecommitdiff
path: root/models/base_model.py
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2018-02-09 18:40:19 +0000
committertingchunw <tingchunw@nvidia.com>2018-02-09 18:40:19 +0000
commit736a2dc9afef418820e9c52f4f3b38460360b9f2 (patch)
treeefc34cde211b65f6f310bb82d76e7892ca720c2c /models/base_model.py
parentedf910b1c1d02020b31782ab4c3b6ebf9af8c323 (diff)
fix python version issue
Diffstat (limited to 'models/base_model.py')
-rwxr-xr-xmodels/base_model.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/models/base_model.py b/models/base_model.py
index d3879d0..88e0587 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -2,6 +2,7 @@
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import os
import torch
+import sys
class BaseModel(torch.nn.Module):
def name(self):
@@ -70,8 +71,11 @@ class BaseModel(torch.nn.Module):
print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
- from sets import Set
- not_initialized = Set()
+ if sys.version_info >= (3,0):
+ not_initialized = set()
+ else:
+ from sets import Set
+ not_initialized = Set()
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v