Diffstat (limited to 'train.py')
-rwxr-xr-x  train.py  9
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/train.py b/train.py
index 15f0e37..c2f7ae1 100755
--- a/train.py
+++ b/train.py
@@ -111,11 +111,10 @@ for epoch in range(start_epoch, start_epoch + opt.niter):
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
 
     ### save model for this epoch
-    if epoch % opt.save_epoch_freq == 0:
-        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
-        model.module.save('latest')
-        model.module.save(epoch)
-        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
+    print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
+    model.module.save('latest')
+    model.module.save(epoch)
+    np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
 
     ### instead of only training the local enhancer, train the entire network after certain iterations
     # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
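For context, a minimal sketch of the saving logic that results from this change: the
opt.save_epoch_freq gate is removed, so a checkpoint is written at the end of every
epoch. The names model, total_steps, and iter_path are assumed from the surrounding
train.py (a DataParallel-wrapped model exposing .save(), the running iteration count,
and the path of the resume-state file); the wrapper function below is illustrative
only, not code from the repository.

import numpy as np

def save_epoch_checkpoints(model, epoch, total_steps, iter_path):
    # Runs unconditionally at the end of each epoch (no save_epoch_freq check).
    print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
    model.module.save('latest')   # overwrite the rolling 'latest' checkpoint
    model.module.save(epoch)      # also keep a per-epoch snapshot
    # Record (next epoch, iteration 0) so an interrupted run can resume from here.
    np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')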