### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

# Recursive augmentation script: picks random start frames from an image
# sequence, feeds each one through the generator, then repeatedly feeds the
# generator's own output back in. Every generated frame is written to the
# render directory and symlinked into ./datasets/<sequence>/train_A, with a
# frame of the source sequence symlinked under the same name in train_B.

import os
import sys
# Must run before the `run_engine` / `img_ops` imports below, which may be
# resolved through this extra search path.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../live-cortex/rpc/'))
from collections import OrderedDict
from options.test_options import TestOptions
from options.dataset_options import DatasetOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
from run_engine import run_trt_engine, run_onnx
from datetime import datetime
from PIL import Image, ImageOps
from shutil import copyfile, rmtree
from random import randint
from img_ops import read_sequence

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


def get_transform(opt, method=Image.BICUBIC, normalize=True):
    """Build the preprocessing pipeline applied to every input image.

    The image is first resized so both sides are multiples of the network's
    total downsampling factor, then converted to a tensor and, optionally,
    normalized with mean 0.5 / std 0.5 per channel.

    Args:
        opt: options object; reads ``n_downsample_global``, ``netG`` and,
            when ``netG == 'local'``, ``n_local_enhancers``.
        method: PIL resampling filter used when a resize is needed.
        normalize: when true, append the 0.5/0.5 normalization step.

    Returns:
        A ``transforms.Compose`` instance.
    """
    base = float(2 ** opt.n_downsample_global)
    if opt.netG == 'local':
        base *= (2 ** opt.n_local_enhancers)

    transform_list = [
        transforms.Lambda(lambda img: __make_power_2(img, base, method)),
        transforms.ToTensor(),
    ]
    if normalize:
        transform_list.append(
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(transform_list)


def normalize():
    """Return the per-channel mean 0.5 / std 0.5 normalization transform."""
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so width and height are the nearest multiples of ``base``.

    The image is returned untouched when both sides already qualify.
    """
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


def _cast_inputs(inputs, data_type):
    """Return ``inputs`` with every tensor value cast for ``data_type``.

    ``16`` selects half precision and ``8`` selects uint8; any other value
    returns ``inputs`` unchanged. Non-tensor values (such as the integer
    placeholder used for the instance map) are passed through as they are,
    because they have no ``half()`` / ``byte()`` method.
    """
    if data_type == 16:
        method_name = 'half'
    elif data_type == 8:
        # torch.Tensor has no ``uint8()`` method; ``byte()`` is the uint8 cast.
        method_name = 'byte'
    else:
        return inputs
    return {
        key: getattr(value, method_name)() if torch.is_tensor(value) else value
        for key, value in inputs.items()
    }


opt = TestOptions().parse(save=False)
data_opt = DatasetOptions().parse(opt.unknown)
opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

if data_opt.tag == '':
    d = datetime.now()
    tag = data_opt.tag = "{}_{}".format(
        opt.name,  # opt.experiment,
        d.strftime('%Y%m%d%H%M')
    )
else:
    tag = data_opt.tag

opt.render_dir = os.path.join(opt.results_dir, opt.name, opt.which_epoch)

print('tag:', tag)
print('render_dir:', opt.render_dir)
util.mkdir(opt.render_dir)

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

# NOTE(review): `model` is only created when neither an engine nor an ONNX
# file is requested, but the loop below always calls `model.inference` --
# confirm whether the engine/ONNX paths are meant to be supported here.
if not opt.engine and not opt.onnx:
    model = create_model(opt)
    if opt.data_type == 16:
        model.half()
    elif opt.data_type == 8:
        model.type(torch.uint8)
    if opt.verbose:
        print(model)

# NOTE(review): the sequence is read via `data_opt.sequence_name` but reported
# and linked via `data_opt.sequence` -- confirm both options exist.
sequence = read_sequence(data_opt.sequence_name, '')
print("Got sequence {}, {} images".format(data_opt.sequence, len(sequence)))
_len = len(sequence) - data_opt.augment_take
if _len <= 0:
    print("Got empty sequence...")
    data_opt.processing = False
    # `rpc_client` is never created in this script; only notify when a client
    # has been injected into the module namespace, so this error path exits
    # with status 1 instead of dying on a NameError.
    rpc_client = globals().get('rpc_client')
    if rpc_client is not None:
        rpc_client.send_status('processing', False)
    sys.exit(1)

transform = get_transform(opt)

dataset_dir = os.path.join("./datasets/", data_opt.sequence)

# add augment name
for m in range(data_opt.augment_take):
    # randint is inclusive on both ends.
    # NOTE(review): each chain reads sequence[i] .. sequence[i + 1] and the
    # bound is derived from `augment_take`, so `i + 1` can run past the end of
    # the sequence when `augment_take` is 1 -- confirm the intended bound.
    i = randint(0, _len)
    for n in range(data_opt.augment_make):
        index = i + n
        if n == 0:
            # The first step of a chain starts from a real frame.
            A_path = sequence[i]
        else:
            # Later steps start from the frame generated by the previous step.
            A_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}.png".format(m, index))
        with Image.open(A_path) as A:
            A_tensor = transform(A.convert('RGB'))

        inst_tensor = 0
        # The tensor is built by hand rather than by a DataLoader, so the
        # batch dimension has to be added explicitly; `generated.data[0]`
        # below expects a batched result.
        input_dict = {'label': A_tensor.unsqueeze(0), 'inst': inst_tensor}
        input_dict = _cast_inputs(input_dict, opt.data_type)

        generated = model.inference(input_dict['label'], input_dict['inst'])

        frame_name = "recur_{:05d}_{:05d}.png".format(m, index+1)
        tmp_path = os.path.join(opt.render_dir, "recur_{:05d}_{:05d}_tmp.png".format(m, index+1))
        next_path = os.path.join(opt.render_dir, frame_name)
        print('process image... %i' % index)

        im = util.tensor2im(generated.data[0])
        image_pil = Image.fromarray(im, mode='RGB')
        # Write under a temporary name and rename, so a partially written
        # file never appears under the final name.
        image_pil.save(tmp_path)
        os.rename(tmp_path, next_path)

        # A relative symlink target is resolved relative to the directory
        # holding the link, not the working directory, so make both targets
        # absolute to keep the links from dangling.
        os.symlink(os.path.abspath(next_path),
                   os.path.join(dataset_dir, "train_A", frame_name))
        # NOTE(review): every step of a chain links train_B to sequence[i+1]
        # (the frame after the chain's start), not sequence[index+1] --
        # confirm which pairing is intended.
        os.symlink(os.path.abspath(sequence[i+1]),
                   os.path.join(dataset_dir, "train_B", frame_name))