diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-06-13 16:16:49 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-06-13 16:16:49 +0200 |
| commit | fadf51822689741cd1a7693b44ac4a0f53c975c8 (patch) | |
| tree | 2e395d55f78b3e669f6ece7f197b69d76b2e1aca | |
| parent | 810d391401fba9a62157becdbccfa7188cbc0d16 (diff) | |
recursive dataset stuff
| -rwxr-xr-x | data/custom_dataset_data_loader.py | 11 | ||||
| -rw-r--r-- | data/recursive_dataset.py | 44 | ||||
| -rw-r--r-- | data/sequence_dataset.py | 42 | ||||
| -rw-r--r-- | recursive.py | 45 | ||||
| -rw-r--r-- | recursive.sh | 11 | ||||
| -rwxr-xr-x | train.sh | 2 |
6 files changed, 153 insertions, 2 deletions
diff --git a/data/custom_dataset_data_loader.py b/data/custom_dataset_data_loader.py index 0b98254..89ae554 100755 --- a/data/custom_dataset_data_loader.py +++ b/data/custom_dataset_data_loader.py @@ -4,8 +4,15 @@ from data.base_data_loader import BaseDataLoader def CreateDataset(opt): dataset = None - from data.aligned_dataset import AlignedDataset - dataset = AlignedDataset() + if opt.phase == 'recursive': + from data.recursive_dataset import RecursiveDataset + dataset = RecursiveDataset() + elif opt.phase == 'sequence': + from data.sequence_dataset import SequenceDataset + dataset = SequenceDataset() + else: + from data.aligned_dataset import AlignedDataset + dataset = AlignedDataset() print("dataset [%s] was created" % (dataset.name())) dataset.initialize(opt) diff --git a/data/recursive_dataset.py b/data/recursive_dataset.py new file mode 100644 index 0000000..14ce906 --- /dev/null +++ b/data/recursive_dataset.py @@ -0,0 +1,44 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os.path +from data.base_dataset import BaseDataset, get_params, get_transform, normalize +from data.image_folder import make_dataset +from PIL import Image + +class RecursiveDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + + ### input A (label maps) + self.dir_A = opt.dataroot + self.A_paths = sorted(make_dataset(self.dir_A)) + + self.dataset_size = len(self.A_paths) + + def __getitem__(self, index): + ### input A (label maps) + A_path = os.path.join(self.opt.dataroot, "frame_{:05d}.png".format(index)) + if not os.path.exists(A_path): + # print() + while not os.path.exists(A_path): + # print('sleeping for {}'.format(self.opt.poll_delay)) + time.sleep(self.opt.poll_delay) + # print("got {}".format(A_path)) + A = Image.open(A_path) + params = get_params(self.opt, A.size) + transform_A = get_transform(self.opt, params) + A_tensor = transform_A(A.convert('RGB')) + + B_tensor = inst_tensor = feat_tensor = 0 + + input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, + 'feat': feat_tensor, 'path': A_path} + + return input_dict + + def __len__(self): + return len(self.A_paths) + + def name(self): + return 'RecursiveDataset'
\ No newline at end of file diff --git a/data/sequence_dataset.py b/data/sequence_dataset.py new file mode 100644 index 0000000..3eaa12b --- /dev/null +++ b/data/sequence_dataset.py @@ -0,0 +1,42 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os.path +from data.base_dataset import BaseDataset, get_params, get_transform, normalize +from data.image_folder import make_dataset +from PIL import Image + +class SequenceDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + + ### input A (label maps) + self.dir_A = opt.dataroot + self.A_paths = sorted(make_dataset(self.dir_A)) + + self.dataset_size = len(self.A_paths) + + def __getitem__(self, index): + ### input A (label maps) + A_path = self.A_paths[index] + A = Image.open(A_path) + params = get_params(self.opt, A.size) + if self.opt.label_nc == 0: + transform_A = get_transform(self.opt, params) + A_tensor = transform_A(A.convert('RGB')) + else: + transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) + A_tensor = transform_A(A) * 255.0 + + B_tensor = inst_tensor = feat_tensor = 0 + + input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, + 'feat': feat_tensor, 'path': A_path} + + return input_dict + + def __len__(self): + return len(self.A_paths) + + def name(self): + return 'SequenceDataset'
\ No newline at end of file diff --git a/recursive.py b/recursive.py new file mode 100644 index 0000000..dc08b28 --- /dev/null +++ b/recursive.py @@ -0,0 +1,45 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import os +from collections import OrderedDict +from options.test_options import TestOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +import util.util as util +from util.visualizer import Visualizer +from util import html +import torch +from run_engine import run_trt_engine, run_onnx + +opt = TestOptions().parse(save=False) +opt.nThreads = 1 # test code only supports nThreads = 1 +opt.batchSize = 1 # test code only supports batchSize = 1 +opt.serial_batches = True # no shuffle +opt.no_flip = True # no flip + +data_loader = CreateDataLoader(opt) +dataset = data_loader.load_data() +visualizer = Visualizer(opt) +# create website +web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) +webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) + +for i, data in enumerate(dataset): + if i >= opt.how_many: + break + if opt.data_type == 16: + data['label'] = data['label'].half() + data['inst'] = data['inst'].half() + elif opt.data_type == 8: + data['label'] = data['label'].uint8() + data['inst'] = data['inst'].uint8() + minibatch = 1 + generated = model.inference(data['label'], data['inst']) + + visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), + ('synthesized_image', util.tensor2im(generated.data[0]))]) + img_path = data['path'] + print('process image... %s' % img_path) + visualizer.save_images(webpage, visuals, img_path) + +webpage.save() diff --git a/recursive.sh b/recursive.sh new file mode 100644 index 0000000..9bd61b6 --- /dev/null +++ b/recursive.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +if [ "$1" == "" ]; then + echo "Usage: $0 [dataset]" + exit 1 +fi +if [ -e "./checkpoints/${1}" ]; then + python train.py --dataroot "./datasets/${1}/" --name "$1" --label_nc 0 --no_instance --continue_train --which_epoch latest +else + python train.py --dataroot "./datasets/${1}/" --name "$1" --label_nc 0 --no_instance +fi @@ -1,3 +1,5 @@ +#!/bin/bash + if [ "$1" == "" ]; then echo "Usage: $0 [dataset]" exit 1 |
