summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xaugment.py4
-rwxr-xr-xtrain.py2
2 files changed, 3 insertions, 3 deletions
diff --git a/augment.py b/augment.py
index 1d44fe2..e34d5cd 100755
--- a/augment.py
+++ b/augment.py
@@ -87,7 +87,7 @@ if not opt.engine and not opt.onnx:
sequence = read_sequence(data_opt.sequence_name, '')
print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence)))
-_len = len(sequence) - data_opt.augment_take
+_len = len(sequence) - data_opt.augment_take - 1
if _len <= 0:
print("Got empty sequence...")
@@ -149,6 +149,6 @@ for m in range(data_opt.augment_take):
if os.path.exists(frame_B):
os.unlink(frame_B)
os.symlink(os.path.abspath(next_path), os.path.abspath(frame_A))
- os.symlink(os.path.abspath(sequence[i+1]), os.path.abspath(frame_B))
+ os.symlink(os.path.abspath(sequence[index+1]), os.path.abspath(frame_B))
diff --git a/train.py b/train.py
index 0fe892d..92ef89d 100755
--- a/train.py
+++ b/train.py
@@ -121,5 +121,5 @@ for epoch in range(start_epoch, start_epoch + opt.niter + opt.niter_decay + 1):
model.module.update_fixed_params()
### linearly decay learning rate after certain iterations
- if epoch > opt.niter:
+ if opt.niter != 0 and epoch > opt.niter:
model.module.update_learning_rate()