summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2018-09-05 12:59:08 +0200
committerjules@lens <julescarbon@gmail.com>2018-09-05 12:59:08 +0200
commitc2421b5c8eea3b3b3769ac64119e4445e616e963 (patch)
tree3b997b6450c3202727382d1e12e41d4eae827341
parent3e66eb779b23be9f4cfdc3fc57c523ea2a0fd5c4 (diff)
augment script
-rwxr-xr-x.gitignore5
-rwxr-xr-xaugment.py26
-rwxr-xr-xaugment.sh10
-rw-r--r--recursive.py2
4 files changed, 37 insertions, 6 deletions
diff --git a/.gitignore b/.gitignore
index 5af91f1..16bfab3 100755
--- a/.gitignore
+++ b/.gitignore
@@ -43,3 +43,8 @@ test/data/legacy_serialized.pt
*.DS_Store
*~
+recursive
+renders
+results
+sequences
+
diff --git a/augment.py b/augment.py
index 37e8f23..1d44fe2 100755
--- a/augment.py
+++ b/augment.py
@@ -86,7 +86,7 @@ if not opt.engine and not opt.onnx:
print(model)
sequence = read_sequence(data_opt.sequence_name, '')
-print("Got sequence {}, {} images".format(data_opt.sequence, len(sequence)))
+print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence)))
_len = len(sequence) - data_opt.augment_take
if _len <= 0:
@@ -107,16 +107,22 @@ for m in range(data_opt.augment_take):
index = i + n
if n == 0:
A_path = sequence[i]
+ if opt.verbose:
+ print(A_path)
A = Image.open(A_path)
A_tensor = transform(A.convert('RGB'))
else:
+ if opt.verbose:
+ print(A_path)
A_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}.png".format(m, index))
A = Image.open(A_path)
A_tensor = transform(A.convert('RGB'))
- B_path = sequence[index+1]
- inst_tensor = 0
- input_dict = {'label': A_tensor, 'inst': inst_tensor}
+ inst_tensor = torch.LongTensor([0])
+ if opt.verbose:
+ print(A_tensor, inst_tensor)
+
+ data = {'label': A_tensor.unsqueeze(0), 'inst': inst_tensor}
if opt.data_type == 16:
data['label'] = data['label'].half()
@@ -136,5 +142,13 @@ for m in range(data_opt.augment_take):
image_pil.save(tmp_path)
os.rename(tmp_path, next_path)
- os.symlink(next_path, os.path.join("./datasets/", data_opt.sequence, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1)))
- os.symlink(sequence[i+1], os.path.join("./datasets/", data_opt.sequence, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1)))
+ frame_A = os.path.join("./datasets/", data_opt.sequence_name, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1))
+ frame_B = os.path.join("./datasets/", data_opt.sequence_name, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1))
+ if os.path.exists(frame_A):
+ os.unlink(frame_A)
+ if os.path.exists(frame_B):
+ os.unlink(frame_B)
+ os.symlink(os.path.abspath(next_path), os.path.abspath(frame_A))
+ os.symlink(os.path.abspath(sequence[i+1]), os.path.abspath(frame_B))
+
+
diff --git a/augment.sh b/augment.sh
new file mode 100755
index 0000000..47f7e74
--- /dev/null
+++ b/augment.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+if [ "$1" == "" ]; then
+ echo "Usage: $0 [dataset] [take] [make]"
+ exit 1
+fi
+python augment.py --dataroot "./datasets/${1}/" --name "$1" --label_nc 0 --no_instance --verbose \
+ --results_dir "./recursive" --module_name pix2pixhd --model pix2pixHD --sequence-name "$1" \
+ --augment-take "$2" --augment-make "$3" --phase recursive
+
diff --git a/recursive.py b/recursive.py
index 5ab6ff7..bacc3ce 100644
--- a/recursive.py
+++ b/recursive.py
@@ -62,6 +62,8 @@ for i, data in enumerate(dataset):
data['label'] = data['label'].uint8()
data['inst'] = data['inst'].uint8()
minibatch = 1
+ print(data['label'])
+ print(data['inst'])
generated = model.inference(data['label'], data['inst'])
last_path = opt.render_dir + "frame_{:05d}.png".format(i)