summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train.py b/train.py
index 0834d37..4d80eb6 100644
--- a/train.py
+++ b/train.py
@@ -12,15 +12,15 @@ print('#training images = %d' % dataset_size)
model = create_model(opt)
visualizer = Visualizer(opt)
-
total_steps = 0
-for epoch in range(1, opt.niter + opt.niter_decay + 1):
+for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
+ epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
total_steps += opt.batchSize
- epoch_iter = total_steps - dataset_size * (epoch - 1)
+ epoch_iter += opt.batchSize
model.set_input(data)
model.optimize_parameters()