From c2421b5c8eea3b3b3769ac64119e4445e616e963 Mon Sep 17 00:00:00 2001 From: "jules@lens" Date: Wed, 5 Sep 2018 12:59:08 +0200 Subject: augment script --- augment.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) (limited to 'augment.py') diff --git a/augment.py b/augment.py index 37e8f23..1d44fe2 100755 --- a/augment.py +++ b/augment.py @@ -86,7 +86,7 @@ if not opt.engine and not opt.onnx: print(model) sequence = read_sequence(data_opt.sequence_name, '') -print("Got sequence {}, {} images".format(data_opt.sequence, len(sequence))) +print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence))) _len = len(sequence) - data_opt.augment_take if _len <= 0: @@ -107,16 +107,22 @@ for m in range(data_opt.augment_take): index = i + n if n == 0: A_path = sequence[i] + if opt.verbose: + print(A_path) A = Image.open(A_path) A_tensor = transform(A.convert('RGB')) else: + if opt.verbose: + print(A_path) A_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}.png".format(m, index)) A = Image.open(A_path) A_tensor = transform(A.convert('RGB')) - B_path = sequence[index+1] - inst_tensor = 0 - input_dict = {'label': A_tensor, 'inst': inst_tensor} + inst_tensor = torch.LongTensor([0]) + if opt.verbose: + print(A_tensor, inst_tensor) + + data = {'label': A_tensor.unsqueeze(0), 'inst': inst_tensor} if opt.data_type == 16: data['label'] = data['label'].half() @@ -136,5 +142,13 @@ for m in range(data_opt.augment_take): image_pil.save(tmp_path) os.rename(tmp_path, next_path) - os.symlink(next_path, os.path.join("./datasets/", data_opt.sequence, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1))) - os.symlink(sequence[i+1], os.path.join("./datasets/", data_opt.sequence, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1))) + frame_A = os.path.join("./datasets/", data_opt.sequence_name, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1)) + frame_B = os.path.join("./datasets/", data_opt.sequence_name, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1)) + if os.path.exists(frame_A): + os.unlink(frame_A) + if os.path.exists(frame_B): + os.unlink(frame_B) + os.symlink(os.path.abspath(next_path), os.path.abspath(frame_A)) + os.symlink(os.path.abspath(sequence[i+1]), os.path.abspath(frame_B)) + + -- cgit v1.2.3-70-g09d2