summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-08-26 18:28:29 -0700
committerjunyanz <junyanz@berkeley.edu>2017-08-26 18:28:29 -0700
commit8bc6778456c86a8b9d88362efb195635d2a0dac3 (patch)
tree0a068f150820d3d31ef1609634ba340cefe719ef /train.py
parentf085c5c977af48cfaed3c61b0a8b9c951ad88337 (diff)
add epoch_count
Diffstat (limited to 'train.py')
-rw-r--r--train.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/train.py b/train.py
index 0834d37..4d80eb6 100644
--- a/train.py
+++ b/train.py
@@ -12,15 +12,15 @@ print('#training images = %d' % dataset_size)
model = create_model(opt)
visualizer = Visualizer(opt)
-
total_steps = 0
-for epoch in range(1, opt.niter + opt.niter_decay + 1):
+for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
+ epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
total_steps += opt.batchSize
- epoch_iter = total_steps - dataset_size * (epoch - 1)
+ epoch_iter += opt.batchSize
model.set_input(data)
model.optimize_parameters()