-rwxr-xr-x  augment.py   1
-rwxr-xr-x  train.py    47
2 files changed, 24 insertions, 24 deletions
diff --git a/augment.py b/augment.py
index cb75880..b5894e1 100755
--- a/augment.py
+++ b/augment.py
@@ -72,6 +72,7 @@ if opt.which_epoch == 'latest':
     if os.path.exists(iter_path):
         try:
             current_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
+            current_epoch = int(current_epoch) - 1
         except:
             current_epoch, epoch_iter = 1, 0
     print('Loading epoch %d' % (current_epoch,))
diff --git a/train.py b/train.py
index c2f7ae1..2157440 100755
--- a/train.py
+++ b/train.py
@@ -22,6 +22,7 @@ if opt.continue_train:
     print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
 else:
     start_epoch, epoch_iter = 1, 0
+    print('Initializing a new epoch')

 if opt.debug:
     opt.display_freq = 1
@@ -33,7 +34,7 @@ if opt.debug:
 data_loader = CreateDataLoader(opt)
 dataset = data_loader.load_data()
 dataset_size = len(data_loader)
-print('#training images = %d' % dataset_size)
+print('dataset size = %d' % dataset_size)

 model = create_model(opt)
 visualizer = Visualizer(opt)
@@ -44,7 +45,7 @@ display_delta = total_steps % opt.display_freq
 print_delta = total_steps % opt.print_freq
 save_delta = total_steps % opt.save_latest_freq

-print("{} {}".format(start_epoch, start_epoch + opt.niter))
+print("epoch = {} {}".format(start_epoch, start_epoch + opt.niter))
 for epoch in range(start_epoch, start_epoch + opt.niter):
     epoch_start_time = time.time()
@@ -89,37 +90,35 @@ for epoch in range(start_epoch, start_epoch + opt.niter):
         if total_steps % opt.print_freq == print_delta:
             errors = {k: v.data[0] if not isinstance(v, int) else v for k, v in loss_dict.items()}
             t = (time.time() - iter_start_time) / opt.batchSize
-            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
-            visualizer.plot_current_errors(errors, total_steps)
+            #visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+            #visualizer.plot_current_errors(errors, total_steps)
         ### display output images
-        if save_fake:
-            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
-                                   ('synthesized_image', util.tensor2im(generated.data[0])),
-                                   ('real_image', util.tensor2im(data['image'][0]))])
-            visualizer.display_current_results(visuals, epoch, total_steps)
+        #if save_fake:
+        #    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+        #                           ('synthesized_image', util.tensor2im(generated.data[0])),
+        #                           ('real_image', util.tensor2im(data['image'][0]))])
+        #    visualizer.display_current_results(visuals, epoch, total_steps)
         ### save latest model
         if total_steps % opt.save_latest_freq == save_delta:
             print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
             model.module.save('latest')
-            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
-
-    # end of epoch
+            model.module.save(str(epoch))
+            np.savetxt(iter_path, (int(epoch)+1, epoch_iter), delimiter=',', fmt='%d')
     iter_end_time = time.time()
-    print('End of epoch %d / %d \t Time Taken: %d sec' %
-          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
-
-    ### save model for this epoch
+    print('End of epoch %d \t Time Taken: %d sec' % (epoch, time.time() - epoch_start_time))
     print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
+    model.module.save(str(epoch))
+    np.savetxt(iter_path, (int(epoch)+1, 0), delimiter=',', fmt='%d')
     model.module.save('latest')
-    model.module.save(epoch)
-    np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
-    ### instead of only training the local enhancer, train the entire network after certain iterations
-    # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
-    #     model.module.update_fixed_params()
-    ### linearly decay learning rate after certain iterations
-    # if opt.niter != 0 and epoch > opt.niter:
-    #     model.module.update_learning_rate()
+# ### instead of only training the local enhancer, train the entire network after certain iterations
+# # if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
+# #     model.module.update_fixed_params()
+#
+# ### linearly decay learning rate after certain iterations
+# # if opt.niter != 0 and epoch > opt.niter:
+# #     model.module.update_learning_rate()
+
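
For context, a minimal sketch (not part of the commit) of the iter.txt round-trip these two files now agree on: train.py records the *next* epoch to run whenever it saves, and augment.py subtracts one on load to recover the last completed epoch. The bare 'iter.txt' path below is a stand-in; the repo builds the real path from the checkpoints directory.

# Sketch only, assuming the save/load convention shown in the diff above.
import numpy as np

iter_path = 'iter.txt'  # stand-in for os.path.join(checkpoints_dir, name, 'iter.txt')

# train.py side: after saving epoch 5, record the next epoch to run
epoch, epoch_iter = 5, 0
np.savetxt(iter_path, (int(epoch) + 1, epoch_iter), delimiter=',', fmt='%d')

# augment.py side: load, then step back one epoch to the last saved checkpoint
current_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
current_epoch = int(current_epoch) - 1
print('Loading epoch %d' % current_epoch)  # prints: Loading epoch 5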