summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rwxr-xr-xtrain.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/train.py b/train.py
index 92ef89d..be23dd6 100755
--- a/train.py
+++ b/train.py
@@ -44,13 +44,13 @@ display_delta = total_steps % opt.display_freq
print_delta = total_steps % opt.print_freq
save_delta = total_steps % opt.save_latest_freq
-print("{} {}".format(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1))
+print("{} {}".format(start_epoch, start_epoch + opt.niter))
-for epoch in range(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1):
+for epoch in range(start_epoch, start_epoch + opt.niter):
epoch_start_time = time.time()
- if epoch != start_epoch:
- epoch_iter = epoch_iter % dataset_size
- for i, data in enumerate(dataset, start=epoch_iter):
+ # if epoch != start_epoch:
+ # epoch_iter = epoch_iter % dataset_size
+ for i, data in enumerate(dataset): #, start=epoch_iter
iter_start_time = time.time()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
@@ -117,9 +117,9 @@ for epoch in range(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1):
np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
### instead of only training the local enhancer, train the entire network after certain iterations
- if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
- model.module.update_fixed_params()
+ # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
+ # model.module.update_fixed_params()
### linearly decay learning rate after certain iterations
- if opt.niter != 0 and epoch > opt.niter:
- model.module.update_learning_rate()
+ # if opt.niter != 0 and epoch > opt.niter:
+ # model.module.update_learning_rate()