### Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
### Licensed under the CC BY-NC-SA 4.0 license
### (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
import os
import sys

# Make the live-cortex RPC helpers importable. The module name rpc_client is
# assumed here; it is used below to report processing status and was never
# imported in the original file.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../live-cortex/rpc/'))
import rpc_client

from collections import OrderedDict
from datetime import datetime
from random import randint
from shutil import copyfile, rmtree

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageOps

from options.test_options import TestOptions
from options.dataset_options import DatasetOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
from run_engine import run_trt_engine, run_onnx
from img_ops import read_sequence


def get_transform(opt, method=Image.BICUBIC, normalize=True):
    """Resize to a multiple of the generator's downsampling factor,
    then convert to a [-1, 1] tensor."""
    transform_list = []
    base = float(2 ** opt.n_downsample_global)
    if opt.netG == 'local':
        base *= (2 ** opt.n_local_enhancers)
    # e.g. with n_downsample_global=4 and one local enhancer, base is 32,
    # so a 1000x600 input is snapped to 992x608 before inference.
    transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
    transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)


opt = TestOptions().parse(save=False)
data_opt = DatasetOptions().parse(opt.unknown)
opt.nThreads = 1           # test code only supports nThreads = 1
opt.batchSize = 1          # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True         # no flip

if data_opt.tag == '':
    d = datetime.now()
    tag = data_opt.tag = "{}_{}".format(
        opt.name,
        # opt.experiment,
        d.strftime('%Y%m%d%H%M')
    )
else:
    tag = data_opt.tag

if opt.which_epoch == 'latest':
    iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
    if os.path.exists(iter_path):
        try:
            current_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
        except Exception:
            current_epoch, epoch_iter = 1, 0
        print('Resuming from epoch %d at iteration %d' % (current_epoch, epoch_iter))
    else:
        current_epoch, epoch_iter = 1, 0
else:
    current_epoch = int(opt.which_epoch)  # which_epoch is parsed as a string

epoch_id = "{:02d}_{:04d}_{:04d}".format(current_epoch, data_opt.augment_take, data_opt.augment_make)
opt.render_dir = os.path.join(opt.results_dir, opt.name, epoch_id)

# Only the PyTorch path is exercised below; the TensorRT/ONNX runners are
# imported but not wired into the render loop.
if not opt.engine and not opt.onnx:
    model = create_model(opt)
    if opt.data_type == 16:
        model.half()
    elif opt.data_type == 8:
        model.type(torch.uint8)
    if opt.verbose:
        print(model)

sequence = read_sequence(data_opt.sequence_name, '')
print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence)))

# Bound for the random start index; the loop below also reads sequence[index + 2].
_len = len(sequence) - data_opt.augment_take - 2
if _len <= 0:
    print("Got empty sequence...")
    data_opt.processing = False
    rpc_client.send_status('processing', False)
    sys.exit(1)

transform = get_transform(opt)

print('tag:', tag)
print('render_dir:', opt.render_dir)
util.mkdir(opt.render_dir)
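# Recursive rollout: for each of data_opt.augment_take takes, a random start
# frame is drawn from the source sequence. Step 0 of a take runs the generator
# on that real frame; every later step feeds the previous *generated* frame
# back in as input, writing recur_<epoch_id>_<take>_<frame>.png into
# opt.render_dir. Each output is also symlinked into the dataset's train_A
# folder, paired in train_B with the real frame one step further ahead
# (sequence[index + 2]), presumably so a later training pass can fine-tune
# on the model's own rollouts.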
for m in range(data_opt.augment_take):
    i = randint(0, _len)
    index = i
    for n in range(data_opt.augment_make):
        index = i + n
        if n == 0:
            # First step: start from a real frame of the source sequence.
            A_path = sequence[index]
        else:
            # Later steps: feed the previously generated frame back in.
            A_path = os.path.join(opt.render_dir,
                                  "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index))
        if opt.verbose:
            print(A_path)
        A = Image.open(A_path)
        A_tensor = transform(A.convert('RGB'))

        inst_tensor = torch.LongTensor([0])  # no instance map in this pipeline
        if opt.verbose:
            print(A_tensor, inst_tensor)
        data = {'label': A_tensor.unsqueeze(0), 'inst': inst_tensor}
        if opt.data_type == 16:
            data['label'] = data['label'].half()
            data['inst'] = data['inst'].half()
        elif opt.data_type == 8:
            # torch tensors have no .uint8() method; .byte() is the uint8 cast.
            data['label'] = data['label'].byte()
            data['inst'] = data['inst'].byte()

        generated = model.inference(data['label'], data['inst'])

        # Write atomically: save under a temp name, then rename into place.
        tmp_path = os.path.join(opt.render_dir,
                                "recur_{}_{:05d}_{:05d}_tmp.png".format(epoch_id, m, index + 1))
        next_path = os.path.join(opt.render_dir,
                                 "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index + 1))
        print('process image... %i' % index)
        im = util.tensor2im(generated.data[0])
        image_pil = Image.fromarray(im, mode='RGB')
        image_pil.save(tmp_path)
        os.rename(tmp_path, next_path)

        # Expose the generated frame (train_A) and its real one-step-ahead
        # target (train_B) to the training set via symlinks.
        frame_name = "recur_{}_{:05d}_{:05d}.png".format(epoch_id, m, index + 1)
        frame_A = os.path.join("./datasets/", data_opt.sequence_name, "train_A", frame_name)
        frame_B = os.path.join("./datasets/", data_opt.sequence_name, "train_B", frame_name)
        # lexists also catches stale/broken symlinks, which exists() would miss.
        if os.path.lexists(frame_A):
            os.unlink(frame_A)
        if os.path.lexists(frame_B):
            os.unlink(frame_B)
        os.symlink(os.path.abspath(next_path), os.path.abspath(frame_A))
        os.symlink(os.path.abspath(sequence[index + 2]), os.path.abspath(frame_B))
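# Not part of the original script: one way to eyeball a finished take is to
# assemble its frames into a clip with ffmpeg (assuming ffmpeg is installed
# and the render_dir layout above), e.g. for take 0:
#
#   ffmpeg -framerate 24 -pattern_type glob \
#          -i 'results/<name>/<epoch_id>/recur_<epoch_id>_00000_*.png' preview.mp4
#
# The frame rate and output name are arbitrary; <name> and <epoch_id> are
# placeholders for the values printed at startup.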