# Live pix2pixHD renderer: parses TestOptions/DatasetOptions at import time,
# prepares a per-run render directory, and hands `process_live_input` to a
# Listener that drives it.
import os
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../live-cortex/rpc/'))
from collections import OrderedDict
from options.test_options import TestOptions
from options.dataset_options import DatasetOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
from run_engine import run_trt_engine, run_onnx
from datetime import datetime
from PIL import Image, ImageOps
from skimage.transform import resize
from scipy.misc import imresize
import numpy as np
import cv2
import math
import subprocess
import glob
import gevent
from time import sleep
from shutil import copyfile, rmtree
from img_ops import read_sequence, process_image, mix_next_image
from listener import Listener

module_name = 'pix2pixhd'

opt = TestOptions().parse(save=False)
data_opt_parser = DatasetOptions()
# presumably opt.unknown holds the CLI args TestOptions did not consume — verify
data_opt = data_opt_parser.parse(opt.unknown)
data_opt.resize_before_sending = True
opt.nThreads = 1           # test code only supports nThreads = 1
opt.batchSize = 1          # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True         # no flip

# Default the run tag to "<name>_<YYYYmmddHHMM>" when none was supplied.
if data_opt.tag == '':
    d = datetime.now()
    tag = data_opt.tag = "{}_{}".format(opt.name, d.strftime('%Y%m%d%H%M'))
else:
    tag = data_opt.tag

opt.render_dir = render_dir = opt.results_dir + "/" + tag + "/"

print('tag:', tag)
print('render_dir:', render_dir)
util.mkdir(render_dir)


def _load_model(opt):
    """Build the PyTorch generator, or return None when a TensorRT engine or
    ONNX file is configured (those backends run without a PyTorch model)."""
    if opt.engine or opt.onnx:
        return None
    model = create_model(opt)
    if opt.data_type == 16:
        model.half()
    elif opt.data_type == 8:
        model.type(torch.uint8)
    if opt.verbose:
        print(model)
    return model


def _cast_inputs(data, data_type):
    """Cast data['label'] and data['inst'] in place to the configured precision."""
    if data_type == 16:
        data['label'] = data['label'].half()
        data['inst'] = data['inst'].half()
    elif data_type == 8:
        # torch.Tensor has no `.uint8()` method; `.to(torch.uint8)` is the cast.
        data['label'] = data['label'].to(torch.uint8)
        data['inst'] = data['inst'].to(torch.uint8)
    return data


def _generate(opt, model, data, minibatch=1):
    """Run one forward pass through the configured backend (TensorRT, ONNX or PyTorch)."""
    inputs = [data['label'], data['inst']]
    if opt.engine:
        return run_trt_engine(opt.engine, minibatch, inputs)
    if opt.onnx:
        return run_onnx(opt.onnx, opt.data_type, minibatch, inputs)
    return model.inference(data['label'], data['inst'])


def _reload_checkpoint(opt, data_opt, model):
    """Load the generator weights for data_opt.epoch / data_opt.checkpoint_name,
    then clear the data_opt.load_checkpoint request."""
    checkpoint_fn = "{}_net_{}.pth".format(data_opt.epoch, 'G')
    checkpoint_path = os.path.join(opt.checkpoints_dir, data_opt.checkpoint_name)
    checkpoint_fn_path = os.path.join(checkpoint_path, checkpoint_fn)
    if model is None:
        # Engine/ONNX backends have no PyTorch network to load weights into.
        print("Checkpoint reload is only supported for the PyTorch model")
    elif os.path.exists(checkpoint_fn_path):
        print("Load checkpoint: {}".format(checkpoint_fn_path))
        model.load_network(model.netG, 'G', data_opt.epoch, checkpoint_path)
    else:
        print("Checkpoint not found: {}".format(checkpoint_fn_path))
    data_opt.load_checkpoint = False


def _reload_sequence(data_opt, sequence, sequence_i):
    """Clear the reload request and switch to data_opt.sequence_name if it is non-empty.

    Returns the (sequence, sequence_i) pair to continue with; the cursor
    restarts at 1 when a new sequence is loaded, otherwise both are unchanged.
    """
    data_opt.load_sequence = False
    new_sequence = read_sequence(data_opt.sequence_name, '')
    if len(new_sequence) == 0:
        print("Sequence not found")
        return sequence, sequence_i
    print("Got sequence {}, {} images, first: {}".format(
        data_opt.sequence_name, len(new_sequence), new_sequence[0]))
    return new_sequence, 1


def _run_generation(opt, data_opt, rpc_client):
    """Run the frame-generation loop.

    Returns False when the requested sequence is empty (nothing generated),
    True once the loop ends (how_many reached, pause requested, or dataset
    exhausted).
    """
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    sequence = read_sequence(data_opt.sequence_name, '')
    print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence)))
    if len(sequence) == 0:
        print("Got empty sequence...")
        return False

    print("First image: {}".format(sequence[0]))
    rpc_client.send_status('processing', True)

    # A previous run may have renamed the render dir away (final_tag), so
    # make sure it exists before writing the first frame into it.
    os.makedirs(opt.render_dir, exist_ok=True)
    start_img_path = os.path.join(opt.render_dir, "frame_{:05d}.png".format(0))
    copyfile(sequence[0], start_img_path)

    model = _load_model(opt)

    sequence_i = 1
    skip_i = 0
    print("generating...")
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            print("generated {} images, exiting".format(i))
            break

        # Control flags below are set on data_opt from outside this loop
        # (presumably by the Listener) and consumed here once.
        if data_opt.load_checkpoint is True:
            _reload_checkpoint(opt, data_opt, model)
        if data_opt.load_sequence is True:
            sequence, sequence_i = _reload_sequence(data_opt, sequence, sequence_i)
        if data_opt.seek_to != 1:
            # 1 is the "no seek pending" sentinel; out-of-range seeks are ignored.
            if 0 < data_opt.seek_to < len(sequence):
                sequence_i = data_opt.seek_to
            data_opt.seek_to = 1

        _cast_inputs(data, opt.data_type)
        generated = _generate(opt, model, data)
        im = util.tensor2im(generated.data[0])
        sequence_i, skip_i = mix_next_image(
            opt, data_opt, rpc_client, im, sequence,
            i=i, sequence_i=sequence_i, skip_i=skip_i)

        if data_opt.pause:
            data_opt.pause = False
            break
        if data_opt.frame_delay > 0:
            gevent.sleep(data_opt.frame_delay)
    return True


def _finalize_render_dir(opt, data_opt):
    """Rename the render dir to a normalized data_opt.final_tag when one was set
    and differs from the current tag; otherwise just report the current tag."""
    if data_opt.tag != data_opt.final_tag and len(data_opt.final_tag) > 0:
        tag = data_opt.final_tag.lower().replace(' ', '_').replace('__', '_')
        print("final result: {}".format(tag))
        os.rename(opt.render_dir, opt.results_dir + "/" + tag + "/")
        sleep(0.1)
    else:
        print("final result: {}".format(data_opt.tag))


def process_live_input(opt, data_opt, rpc_client):
    """Render frames for data_opt.sequence_name into opt.render_dir.

    Args:
        opt: parsed TestOptions (model, backend and output settings).
        data_opt: parsed DatasetOptions; its flags (processing, pause,
            load_checkpoint, load_sequence, seek_to, ...) are read and reset here.
        rpc_client: object exposing ``send_status(name, value)``.

    Only one run is allowed at a time: a call made while data_opt.processing
    is set returns immediately. The processing flag is always cleared and a
    ``processing=False`` status is always sent when the run ends, even if it
    raises.
    """
    print(">>> Process live HD input")
    if data_opt.processing:
        print("Already processing...")
        return

    data_opt.processing = True
    try:
        completed = _run_generation(opt, data_opt, rpc_client)
    finally:
        data_opt.processing = False
        rpc_client.send_status('processing', False)

    if completed:
        _finalize_render_dir(opt, data_opt)
        print("done")


if __name__ == '__main__':
    listener = Listener(opt, data_opt, data_opt_parser, process_live_input)
    listener.connect()