summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/train.py b/train.py
index f6072c7..61b596a 100644
--- a/train.py
+++ b/train.py
@@ -16,10 +16,13 @@ total_steps = 0
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
+ iter_data_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
+ if total_steps % opt.print_freq == 0:
+ t_data = iter_start_time - iter_data_time
visualizer.reset()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
@@ -33,7 +36,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
t = (time.time() - iter_start_time) / opt.batchSize
- visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+ visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
if opt.display_id > 0:
visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
@@ -42,6 +45,7 @@ for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
(epoch, total_steps))
model.save('latest')
+ iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))