diff options
| author | Taesung Park <taesung_park@berkeley.edu> | 2017-12-10 23:04:41 -0800 |
|---|---|---|
| committer | Taesung Park <taesung_park@berkeley.edu> | 2017-12-10 23:04:41 -0800 |
| commit | f33f098be9b25c3b62523540c9c703af1db0b1c0 (patch) | |
| tree | 9b51e547067b46ad8b55ddb34b207825550df867 /models/base_model.py | |
| parent | 3d2c534933b356dc313a620639a713cb940dc756 (diff) | |
| parent | 2d96edbee5a488a7861833731a2cb71b23b55727 (diff) | |
merged conflicts
Diffstat (limited to 'models/base_model.py')
| -rw-r--r-- | models/base_model.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/models/base_model.py b/models/base_model.py index 446a903..9b55afe 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -44,13 +44,14 @@ class BaseModel(): save_path = os.path.join(self.save_dir, save_filename) torch.save(network.cpu().state_dict(), save_path) if len(gpu_ids) and torch.cuda.is_available(): - network.cuda(device_id=gpu_ids[0]) + network.cuda(gpu_ids[0]) # helper loading function that can be used by subclasses def load_network(self, network, network_label, epoch_label): save_filename = '%s_net_%s.pth' % (epoch_label, network_label) save_path = os.path.join(self.save_dir, save_filename) network.load_state_dict(torch.load(save_path)) + # update learning rate (called once every epoch) def update_learning_rate(self): for scheduler in self.schedulers: |
