diff options
| -rwxr-xr-x | .gitignore | 5 | ||||
| -rwxr-xr-x | augment.py | 26 | ||||
| -rwxr-xr-x | augment.sh | 10 | ||||
| -rw-r--r-- | recursive.py | 2 |
4 files changed, 37 insertions, 6 deletions
@@ -43,3 +43,8 @@ test/data/legacy_serialized.pt *.DS_Store *~ +recursive +renders +results +sequences + @@ -86,7 +86,7 @@ if not opt.engine and not opt.onnx: print(model) sequence = read_sequence(data_opt.sequence_name, '') -print("Got sequence {}, {} images".format(data_opt.sequence, len(sequence))) +print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence))) _len = len(sequence) - data_opt.augment_take if _len <= 0: @@ -107,16 +107,22 @@ for m in range(data_opt.augment_take): index = i + n if n == 0: A_path = sequence[i] + if opt.verbose: + print(A_path) A = Image.open(A_path) A_tensor = transform(A.convert('RGB')) else: + if opt.verbose: + print(A_path) A_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}.png".format(m, index)) A = Image.open(A_path) A_tensor = transform(A.convert('RGB')) - B_path = sequence[index+1] - inst_tensor = 0 - input_dict = {'label': A_tensor, 'inst': inst_tensor} + inst_tensor = torch.LongTensor([0]) + if opt.verbose: + print(A_tensor, inst_tensor) + + data = {'label': A_tensor.unsqueeze(0), 'inst': inst_tensor} if opt.data_type == 16: data['label'] = data['label'].half() @@ -136,5 +142,13 @@ for m in range(data_opt.augment_take): image_pil.save(tmp_path) os.rename(tmp_path, next_path) - os.symlink(next_path, os.path.join("./datasets/", data_opt.sequence, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1))) - os.symlink(sequence[i+1], os.path.join("./datasets/", data_opt.sequence, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1))) + frame_A = os.path.join("./datasets/", data_opt.sequence_name, "train_A", "recur_{:05d}_{:05d}.png".format(m, index+1)) + frame_B = os.path.join("./datasets/", data_opt.sequence_name, "train_B", "recur_{:05d}_{:05d}.png".format(m, index+1)) + if os.path.exists(frame_A): + os.unlink(frame_A) + if os.path.exists(frame_B): + os.unlink(frame_B) + os.symlink(os.path.abspath(next_path), os.path.abspath(frame_A)) + os.symlink(os.path.abspath(sequence[i+1]), os.path.abspath(frame_B)) + + diff --git a/augment.sh b/augment.sh new file mode 100755 index 0000000..47f7e74 --- /dev/null +++ b/augment.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +if [ "$1" == "" ]; then + echo "Usage: $0 [dataset] [take] [make]" + exit 1 +fi +python augment.py --dataroot "./datasets/${1}/" --name "$1" --label_nc 0 --no_instance --verbose \ + --results_dir "./recursive" --module_name pix2pixhd --model pix2pixHD --sequence-name "$1" \ + --augment-take "$2" --augment-make "$3" --phase recursive + diff --git a/recursive.py b/recursive.py index 5ab6ff7..bacc3ce 100644 --- a/recursive.py +++ b/recursive.py @@ -62,6 +62,8 @@ for i, data in enumerate(dataset): data['label'] = data['label'].uint8() data['inst'] = data['inst'].uint8() minibatch = 1 + print(data['label']) + print(data['inst']) generated = model.inference(data['label'], data['inst']) last_path = opt.render_dir + "frame_{:05d}.png".format(i) |
