summaryrefslogtreecommitdiff
path: root/augment.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-09-07 14:40:59 +0200
committerJules Laplace <julescarbon@gmail.com>2018-09-07 14:40:59 +0200
commit15efabd225f08921c4ffb83f488710ac7456f9d2 (patch)
tree0bedeee2787017b1f641cd35c59c62d79b21808e /augment.py
parentde7d9fb18fcc8a43e2d365203904514e89dd414e (diff)
augment
Diffstat (limited to 'augment.py')
-rwxr-xr-xaugment.py39
1 files changed, 25 insertions, 14 deletions
diff --git a/augment.py b/augment.py
index e34d5cd..a9ba885 100755
--- a/augment.py
+++ b/augment.py
@@ -12,6 +12,7 @@ import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
+import numpy as np
from run_engine import run_trt_engine, run_onnx
from datetime import datetime
from PIL import Image, ImageOps
@@ -50,7 +51,6 @@ def __make_power_2(img, base, method=Image.BICUBIC):
return img
return img.resize((w, h), method)
-
opt = TestOptions().parse(save=False)
data_opt = DatasetOptions().parse(opt.unknown)
opt.nThreads = 1 # test code only supports nThreads = 1
@@ -67,14 +67,21 @@ if data_opt.tag == '':
else:
tag = data_opt.tag
-opt.render_dir = os.path.join(opt.results_dir, opt.name, opt.which_epoch)
-
-print('tag:', tag)
-print('render_dir:', opt.render_dir)
-util.mkdir(opt.render_dir)
+if opt.current_epoch == 'latest':
+ iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
+ if os.path.exists(iter_path):
+ try:
+ current_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
+ except:
+ current_epoch, epoch_iter = 1, 0
+ print('Resuming from epoch %d at iteration %d' % (current_epoch, epoch_iter))
+ else:
+ current_epoch, epoch_iter = 1, 0
+else:
+ current_epoch = opt.current_epoch
-data_loader = CreateDataLoader(opt)
-dataset = data_loader.load_data()
+epoch_id = "{:02d}_{:04d}_{:04d}".format(current_epoch, data_opt.augment_take, data_opt.augment_make)
+opt.render_dir = os.path.join(opt.results_dir, opt.name, epoch_id)
if not opt.engine and not opt.onnx:
model = create_model(opt)
@@ -97,6 +104,10 @@ if _len <= 0:
transform = get_transform(opt)
+print('tag:', tag)
+print('render_dir:', opt.render_dir)
+util.mkdir(opt.render_dir)
+
# add augment name
for m in range(data_opt.augment_take):
@@ -106,7 +117,7 @@ for m in range(data_opt.augment_take):
for n in range(data_opt.augment_make):
index = i + n
if n == 0:
- A_path = sequence[i]
+ A_path = sequence[index]
if opt.verbose:
print(A_path)
A = Image.open(A_path)
@@ -114,7 +125,7 @@ for m in range(data_opt.augment_take):
else:
if opt.verbose:
print(A_path)
- A_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}.png".format(m, index))
+ A_path = os.path.join(opt.render_dir, "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index))
A = Image.open(A_path)
A_tensor = transform(A.convert('RGB'))
@@ -133,8 +144,8 @@ for m in range(data_opt.augment_take):
minibatch = 1
generated = model.inference(data['label'], data['inst'])
- tmp_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}_tmp.png".format(m, index+1))
- next_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}.png".format(m, index+1))
+ tmp_path = os.path.join(opt.render_dir, "recur_{}_{:05d}_{:05d}_tmp.png".format(epoch_id, m, index+1))
+ next_path = os.path.join(opt.render_dir, "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index+1))
print('process image... %i' % index)
im = util.tensor2im(generated.data[0])
@@ -142,8 +153,8 @@ for m in range(data_opt.augment_take):
image_pil.save(tmp_path)
os.rename(tmp_path, next_path)
- frame_A = os.path.join("./datasets/", data_opt.sequence_name, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1))
- frame_B = os.path.join("./datasets/", data_opt.sequence_name, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1))
+ frame_A = os.path.join("./datasets/", data_opt.sequence_name, "train_A", "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index+1))
+ frame_B = os.path.join("./datasets/", data_opt.sequence_name, "train_B", "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index+1))
if os.path.exists(frame_A):
os.unlink(frame_A)
if os.path.exists(frame_B):