diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-06-18 15:47:25 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-06-18 15:47:25 +0200 |
| commit | c0813619431a83067da73d75717a6aecd6747a1f (patch) | |
| tree | 4ea61f3cc613438084434de2dad8bd574e26f2ea /live.py | |
| parent | 2b228ed309415418d54e284a8ac430e5161a6535 (diff) | |
live script
Diffstat (limited to 'live.py')
| -rw-r--r-- | live.py | 143 |
1 files changed, 143 insertions, 0 deletions
@@ -0,0 +1,143 @@ +import os +import sys +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../live-cortex/rpc/')) +from collections import OrderedDict +from options.test_options import TestOptions +from options.dataset_options import DatasetOptions +from data.data_loader import CreateDataLoader +from models.models import create_model +import util.util as util +from util.visualizer import Visualizer +from util import html +import torch +from run_engine import run_trt_engine, run_onnx +from datetime import datetime +from PIL import Image, ImageOps +from skimage.transform import resize +from scipy.misc import imresize +import numpy as np +import cv2 +import math +import subprocess +import glob +import gevent +from time import sleep +from shutil import copyfile, rmtree + +from img_ops import process_image, mix_next_image +from listener import Listener + +module_name = 'pix2pixhd' + +opt = TestOptions().parse(save=False) +data_opt_parser = DatasetOptions() +data_opt = data_opt_parser.parse(opt.unknown) +data_opt.resize_before_sending = True +opt.nThreads = 1 # test code only supports nThreads = 1 +opt.batchSize = 1 # test code only supports batchSize = 1 +opt.serial_batches = True # no shuffle +opt.no_flip = True # no flip +if data_opt.tag == '': + d = datetime.now() + tag = data_opt.tag = "{}_{}".format( + opt.name, + # opt.experiment, + d.strftime('%Y%m%d%H%M') + ) +else: + tag = data_opt.tag + +opt.render_dir = render_dir = opt.results_dir + opt.name + "/" + tag + "/" + +print('tag:', tag) +print('render_dir:', render_dir) +util.mkdir(render_dir) + +def process_live_input(opt, data_opt, rpc_client): + print(">>> Process live HD input") + if data_opt.processing: + print("Already processing...") + data_opt.processing = True + data_loader = CreateDataLoader(opt) + dataset = data_loader.load_data() + + create_render_dir(opt) + sequence = read_sequence(data_opt.sequence_name, '') + print("Got sequence {}, {} images".format(data_opt.sequence, len(sequence))) + if len(sequence) == 0: + print("Got empty sequence...") + data_opt.processing = False + rpc_client.send_status('processing', False) + return + print("First image: {}".format(sequence[0])) + + rpc_client.send_status('processing', True) + + start_img_path = os.path.join(opt.render_dir, "frame_{:05d}.png".format(0)) + copyfile(sequence[0], start_img_path) + + if not opt.engine and not opt.onnx: + model = create_model(opt) + if opt.data_type == 16: + model.half() + elif opt.data_type == 8: + model.type(torch.uint8) + if opt.verbose: + print(model) + + sequence_i = 1 + + print("generating...") + for i, data in enumerate(data_loader): + if i >= opt.how_many: + print("generated {} images, exiting".format(i)) + break + + if data_opt.load_checkpoint is True: + checkpoint_fn = "{}_net_{}.pth".format(data_opt.epoch, 'G') + checkpoint_path = os.path.join(opt.checkpoints_dir, '', data_opt.checkpoint_name) + checkpoint_fn_path = os.path.join(checkpoint_path, checkpoint_fn) + if os.path.exists(checkpoint_fn_path): + print("Load checkpoint: {}".format(checkpoint_fn_path)) + model.load_network(model.netG, 'G', data_opt.epoch, checkpoint_path) + else: + print("Checkpoint not found: {}".format(checkpoint_fn_path)) + data_opt.load_checkpoint = False + if data_opt.load_sequence is True: + data_opt.load_sequence = False + new_sequence = read_sequence(data_opt.sequence_name, '') + if len(new_sequence) != 0: + print("Got sequence {}, {} images, first: {}".format(data_opt.sequence_name, len(new_sequence), new_sequence[0])) + sequence = new_sequence + sequence_i = 1 + else: + print("Sequence not found") + if data_opt.seek_to != 1: + if data_opt.seek_to > 0 and data_opt.seek_to < len(sequence): + sequence_i = data_opt.seek_to + data_opt.seek_to = 1 + + if opt.data_type == 16: + data['label'] = data['label'].half() + data['inst'] = data['inst'].half() + elif opt.data_type == 8: + data['label'] = data['label'].uint8() + data['inst'] = data['inst'].uint8() + minibatch = 1 + generated = model.inference(data['label'], data['inst']) + + im = util.tensor2im(generated.data[0]) + + sequence_i = mix_next_image(opt, data_opt, rpc_client, im, i, sequence, sequence_i) + + if data_opt.pause: + data_opt.pause = False + break + gevent.sleep(data_opt.frame_delay) + + data_opt.processing = False + rpc_client.send_status('processing', False) + +if __name__ == '__main__': + listener = Listener(opt, data_opt, data_opt_parser, process_live_input) + listener.connect() |
