diff options
| -rwxr-xr-x | models/base_model.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/models/base_model.py b/models/base_model.py index d3879d0..88e0587 100755 --- a/models/base_model.py +++ b/models/base_model.py @@ -2,6 +2,7 @@ ### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). import os import torch +import sys class BaseModel(torch.nn.Module): def name(self): @@ -70,8 +71,11 @@ class BaseModel(torch.nn.Module): print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) except: print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) - from sets import Set - not_initialized = Set() + if sys.version_info >= (3,0): + not_initialized = set() + else: + from sets import Set + not_initialized = Set() for k, v in pretrained_dict.items(): if v.size() == model_dict[k].size(): model_dict[k] = v |
