summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xdata/custom_dataset_data_loader.py11
-rw-r--r--data/recursive_dataset.py44
-rw-r--r--data/sequence_dataset.py42
-rw-r--r--recursive.py45
-rw-r--r--recursive.sh11
-rwxr-xr-xtrain.sh2
6 files changed, 153 insertions, 2 deletions
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py
index 0b98254..89ae554 100755
--- a/data/custom_dataset_data_loader.py
+++ b/data/custom_dataset_data_loader.py
@@ -4,8 +4,15 @@ from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
dataset = None
- from data.aligned_dataset import AlignedDataset
- dataset = AlignedDataset()
+ if opt.phase == 'recursive':
+ from data.recursive_dataset import RecursiveDataset
+ dataset = RecursiveDataset()
+ elif opt.phase == 'sequence':
+ from data.sequence_dataset import SequenceDataset
+ dataset = SequenceDataset()
+ else:
+ from data.aligned_dataset import AlignedDataset
+ dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py
new file mode 100644
index 0000000..14ce906
--- /dev/null
+++ b/data/recursive_dataset.py
@@ -0,0 +1,44 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os.path
+from data.base_dataset import BaseDataset, get_params, get_transform, normalize
+from data.image_folder import make_dataset
+from PIL import Image
+
+class RecursiveDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+
+ ### input A (label maps)
+ self.dir_A = opt.dataroot
+ self.A_paths = sorted(make_dataset(self.dir_A))
+
+ self.dataset_size = len(self.A_paths)
+
+ def __getitem__(self, index):
+ ### input A (label maps)
+ A_path = os.path.join(self.opt.dataroot, "frame_{:05d}.png".format(index))
+ if not os.path.exists(A_path):
+ # print()
+ while not os.path.exists(A_path):
+ # print('sleeping for {}'.format(self.opt.poll_delay))
+ time.sleep(self.opt.poll_delay)
+ # print("got {}".format(A_path))
+ A = Image.open(A_path)
+ params = get_params(self.opt, A.size)
+ transform_A = get_transform(self.opt, params)
+ A_tensor = transform_A(A.convert('RGB'))
+
+ B_tensor = inst_tensor = feat_tensor = 0
+
+ input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
+ 'feat': feat_tensor, 'path': A_path}
+
+ return input_dict
+
+ def __len__(self):
+ return len(self.A_paths)
+
+ def name(self):
+ return 'RecursiveDataset' \ No newline at end of file
diff --git a/data/sequence_dataset.py b/data/sequence_dataset.py
new file mode 100644
index 0000000..3eaa12b
--- /dev/null
+++ b/data/sequence_dataset.py
@@ -0,0 +1,42 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os.path
+from data.base_dataset import BaseDataset, get_params, get_transform, normalize
+from data.image_folder import make_dataset
+from PIL import Image
+
+class SequenceDataset(BaseDataset):
+ def initialize(self, opt):
+ self.opt = opt
+ self.root = opt.dataroot
+
+ ### input A (label maps)
+ self.dir_A = opt.dataroot
+ self.A_paths = sorted(make_dataset(self.dir_A))
+
+ self.dataset_size = len(self.A_paths)
+
+ def __getitem__(self, index):
+ ### input A (label maps)
+ A_path = self.A_paths[index]
+ A = Image.open(A_path)
+ params = get_params(self.opt, A.size)
+ if self.opt.label_nc == 0:
+ transform_A = get_transform(self.opt, params)
+ A_tensor = transform_A(A.convert('RGB'))
+ else:
+ transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
+ A_tensor = transform_A(A) * 255.0
+
+ B_tensor = inst_tensor = feat_tensor = 0
+
+ input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
+ 'feat': feat_tensor, 'path': A_path}
+
+ return input_dict
+
+ def __len__(self):
+ return len(self.A_paths)
+
+ def name(self):
+ return 'SequenceDataset' \ No newline at end of file
diff --git a/recursive.py b/recursive.py
new file mode 100644
index 0000000..dc08b28
--- /dev/null
+++ b/recursive.py
@@ -0,0 +1,45 @@
+### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
+### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
+import os
+from collections import OrderedDict
+from options.test_options import TestOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+import util.util as util
+from util.visualizer import Visualizer
+from util import html
+import torch
+from run_engine import run_trt_engine, run_onnx
+
+opt = TestOptions().parse(save=False)
+opt.nThreads = 1 # test code only supports nThreads = 1
+opt.batchSize = 1 # test code only supports batchSize = 1
+opt.serial_batches = True # no shuffle
+opt.no_flip = True # no flip
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+visualizer = Visualizer(opt)
+# create website
+web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
+
+for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ if opt.data_type == 16:
+ data['label'] = data['label'].half()
+ data['inst'] = data['inst'].half()
+ elif opt.data_type == 8:
+ data['label'] = data['label'].uint8()
+ data['inst'] = data['inst'].uint8()
+ minibatch = 1
+ generated = model.inference(data['label'], data['inst'])
+
+ visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
+ ('synthesized_image', util.tensor2im(generated.data[0]))])
+ img_path = data['path']
+ print('process image... %s' % img_path)
+ visualizer.save_images(webpage, visuals, img_path)
+
+webpage.save()
diff --git a/recursive.sh b/recursive.sh
new file mode 100644
index 0000000..9bd61b6
--- /dev/null
+++ b/recursive.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+if [ "$1" == "" ]; then
+ echo "Usage: $0 [dataset]"
+ exit 1
+fi
+if [ -e "./checkpoints/${1}" ]; then
+ python train.py --dataroot "./datasets/${1}/" --name "$1" --label_nc 0 --no_instance --continue_train --which_epoch latest
+else
+ python train.py --dataroot "./datasets/${1}/" --name "$1" --label_nc 0 --no_instance
+fi
diff --git a/train.sh b/train.sh
index df4969e..9bd61b6 100755
--- a/train.sh
+++ b/train.sh
@@ -1,3 +1,5 @@
+#!/bin/bash
+
if [ "$1" == "" ]; then
echo "Usage: $0 [dataset]"
exit 1