### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""Run generator inference over a dataset and save each result as a frame.

The script parses the test and dataset options, creates a render directory
``<results_dir><name>/<tag>/``, copies ``opt.start_img`` into it as
``frame_00000.png`` and then, for every dataset item ``i`` (up to
``opt.how_many``), writes the generated image as ``frame_{i+1:05d}.png``.
"""
import os
from collections import OrderedDict
from options.test_options import TestOptions
from options.dataset_options import DatasetOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
from run_engine import run_trt_engine, run_onnx
from datetime import datetime
from PIL import Image, ImageOps
from shutil import copyfile, rmtree

opt = TestOptions().parse(save=False)
# Arguments the test parser did not consume are handed to the dataset parser.
data_opt = DatasetOptions().parse(opt.unknown)

opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

# Fall back to "<name>_<YYYYmmddHHMM>" when no explicit tag was given.
if data_opt.tag == '':
    d = datetime.now()
    tag = data_opt.tag = "{}_{}".format(
        opt.name,
        d.strftime('%Y%m%d%H%M')
    )
else:
    tag = data_opt.tag

# The trailing "/" is kept on purpose: the value is stored on ``opt`` and may
# be concatenated with file names elsewhere (verify against other users of
# ``opt.render_dir`` before changing its shape).
opt.render_dir = render_dir = opt.results_dir + opt.name + "/" + tag + "/"

print('tag:', tag)
print('render_dir:', render_dir)
util.mkdir(render_dir)

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

# Seed the render directory with the starting frame.
start_img_path = os.path.join(render_dir, "frame_00000.png")
copyfile(opt.start_img, start_img_path)

# NOTE(review): ``model`` is only created when neither ``opt.engine`` nor
# ``opt.onnx`` is set, yet the loop below always calls ``model.inference``.
# Running with either option therefore fails with a NameError; the imported
# ``run_trt_engine`` / ``run_onnx`` are presumably meant for those paths --
# confirm and wire them in.
if not opt.engine and not opt.onnx:
    model = create_model(opt)
    if opt.data_type == 16:
        model.half()
    elif opt.data_type == 8:
        model.type(torch.uint8)
    if opt.verbose:
        print(model)

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break

    # Cast the inputs to the precision selected by ``opt.data_type``.
    if opt.data_type == 16:
        data['label'] = data['label'].half()
        data['inst'] = data['inst'].half()
    elif opt.data_type == 8:
        # torch.Tensor has no ``uint8()`` method; ``byte()`` is the uint8 cast.
        data['label'] = data['label'].byte()
        data['inst'] = data['inst'].byte()

    print(data['label'])
    print(data['inst'])
    generated = model.inference(data['label'], data['inst'])

    last_path = os.path.join(opt.render_dir, "frame_{:05d}.png".format(i))
    tmp_path = os.path.join(
        opt.render_dir, "frame_{:05d}_tmp.png".format(i + 1)
    )
    next_path = os.path.join(
        opt.render_dir, "frame_{:05d}.png".format(i + 1)
    )

    print('process image... %s' % last_path)
    im = util.tensor2im(generated.data[0])
    image_pil = Image.fromarray(im, mode='RGB')

    # The image is saved under a temporary name and then moved into place --
    # presumably so that a reader of the render directory never sees a
    # partially written frame. ``os.replace`` overwrites an existing target
    # on every platform, whereas ``os.rename`` fails on Windows in that case.
    image_pil.save(tmp_path)
    os.replace(tmp_path, next_path)