diff options
| -rwxr-xr-x | augment.py | 1 | ||||
| -rwxr-xr-x | train.py | 47 |
2 files changed, 24 insertions, 24 deletions
@@ -72,6 +72,7 @@ if opt.which_epoch == 'latest': if os.path.exists(iter_path): try: current_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) + current_epoch = int(current_epoch) - 1 except: current_epoch, epoch_iter = 1, 0 print('Loading epoch %d' % (current_epoch,)) @@ -22,6 +22,7 @@ if opt.continue_train: print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) else: start_epoch, epoch_iter = 1, 0 + print('Initializing a new epoch') if opt.debug: opt.display_freq = 1 @@ -33,7 +34,7 @@ if opt.debug: data_loader = CreateDataLoader(opt) dataset = data_loader.load_data() dataset_size = len(data_loader) -print('#training images = %d' % dataset_size) +print('dataset size = %d' % dataset_size) model = create_model(opt) visualizer = Visualizer(opt) @@ -44,7 +45,7 @@ display_delta = total_steps % opt.display_freq print_delta = total_steps % opt.print_freq save_delta = total_steps % opt.save_latest_freq -print("{} {}".format(start_epoch, start_epoch + opt.niter)) +print("epoch = {} {}".format(start_epoch, start_epoch + opt.niter)) for epoch in range(start_epoch, start_epoch + opt.niter): epoch_start_time = time.time() @@ -89,37 +90,35 @@ for epoch in range(start_epoch, start_epoch + opt.niter): if total_steps % opt.print_freq == print_delta: errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()} t = (time.time() - iter_start_time) / opt.batchSize - visualizer.print_current_errors(epoch, epoch_iter, errors, t) - visualizer.plot_current_errors(errors, total_steps) + #visualizer.print_current_errors(epoch, epoch_iter, errors, t) + #visualizer.plot_current_errors(errors, total_steps) ### display output images - if save_fake: - visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), - ('synthesized_image', util.tensor2im(generated.data[0])), - ('real_image', util.tensor2im(data['image'][0]))]) - visualizer.display_current_results(visuals, epoch, total_steps) + #if save_fake: + # visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), + # ('synthesized_image', util.tensor2im(generated.data[0])), + # ('real_image', util.tensor2im(data['image'][0]))]) + # visualizer.display_current_results(visuals, epoch, total_steps) ### save latest model if total_steps % opt.save_latest_freq == save_delta: print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) model.module.save('latest') - np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') - - # end of epoch + model.module.save(str(epoch)) + np.savetxt(iter_path, (int(epoch)+1, epoch_iter), delimiter=',', fmt='%d') iter_end_time = time.time() - print('End of epoch %d / %d \t Time Taken: %d sec' % - (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) - - ### save model for this epoch + print('End of epoch %d \t Time Taken: %d sec' % (epoch, time.time() - epoch_start_time)) print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) + model.module.save(str(epoch)) + np.savetxt(iter_path, (int(epoch)+1, 0), delimiter=',', fmt='%d') model.module.save('latest') - model.module.save(epoch) - np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') - ### instead of only training the local enhancer, train the entire network after certain iterations - # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): - # model.module.update_fixed_params() - ### linearly decay learning rate after certain iterations - # if opt.niter != 0 and epoch > opt.niter: - # model.module.update_learning_rate() +# ### instead of only training the local enhancer, train the entire network after certain iterations +# # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): +# # model.module.update_fixed_params() +# +# ### linearly decay learning rate after certain iterations +# # if opt.niter != 0 and epoch > opt.niter: +# # model.module.update_learning_rate() + |
