Diffstat (limited to 'train.py')
-rwxr-xr-x  train.py  18
1 file changed, 9 insertions, 9 deletions
@@ -44,13 +44,13 @@
 display_delta = total_steps % opt.display_freq
 print_delta = total_steps % opt.print_freq
 save_delta = total_steps % opt.save_latest_freq
-print("{} {}".format(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1))
+print("{} {}".format(start_epoch, start_epoch + opt.niter))
 
-for epoch in range(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1):
+for epoch in range(start_epoch, start_epoch + opt.niter):
     epoch_start_time = time.time()
-    if epoch != start_epoch:
-        epoch_iter = epoch_iter % dataset_size
-    for i, data in enumerate(dataset, start=epoch_iter):
+    # if epoch != start_epoch:
+    #     epoch_iter = epoch_iter % dataset_size
+    for i, data in enumerate(dataset): #, start=epoch_iter
         iter_start_time = time.time()
         total_steps += opt.batchSize
         epoch_iter += opt.batchSize
@@ -117,9 +117,9 @@ for epoch in range(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1):
         np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
 
     ### instead of only training the local enhancer, train the entire network after certain iterations
-    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
-        model.module.update_fixed_params()
+    # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
+    #     model.module.update_fixed_params()
 
     ### linearly decay learning rate after certain iterations
-    if opt.niter != 0 and epoch > opt.niter:
-        model.module.update_learning_rate()
+    # if opt.niter != 0 and epoch > opt.niter:
+    #     model.module.update_learning_rate()
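Taken together, the two hunks cap training at opt.niter epochs and disable the linear learning-rate decay phase. A minimal sketch (not from the repo) of what the loop-boundary change means, with SimpleNamespace and the chosen values standing in for the script's real parsed options:

from types import SimpleNamespace

# Illustrative stand-in for the parsed training options; real values
# come from the repository's option parser.
opt = SimpleNamespace(niter=100, niter_decay=100)
start_epoch = 1

# Before: opt.niter epochs at a fixed learning rate, then opt.niter_decay
# epochs during which update_learning_rate() decayed it linearly (the +1
# makes the final epoch inclusive).
old_epochs = range(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1)

# After: exactly opt.niter epochs; the decay block is commented out, so
# the learning rate stays fixed for the whole run.
new_epochs = range(start_epoch, start_epoch + opt.niter)

print(len(old_epochs), len(new_epochs))  # 201 100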
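One behavioral note on the enumerate change in the first hunk, shown with plain Python (nothing repo-specific): enumerate's start argument only offsets the returned index, it never skips items, so the removed start=epoch_iter resumed the iteration counter rather than the data.

batches = ['batch0', 'batch1', 'batch2']

# start= shifts the index but still yields every item, so dropping it
# changes only the bookkeeping, not which samples are trained on.
print(list(enumerate(batches, start=2)))
# [(2, 'batch0'), (3, 'batch1'), (4, 'batch2')]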
