import os
import sys

# Extend sys.path with ../live-cortex/rpc/ (relative to this file);
# presumably that is where the `rpc` module imported below lives.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../live-cortex/rpc/'))

from options.test_options import TestOptions
from options.dataset_options import DatasetOptions
from data import CreateRecursiveDataLoader
from models import create_model
# from util.visualizer import Visualizer
from util.util import mkdirs, crop_image
from util import html
from shutil import move, copyfile, rmtree
from PIL import Image, ImageOps
from skimage.transform import resize
# NOTE(review): scipy.misc.imresize was removed in SciPy 1.3; this import
# only works with an older SciPy. Kept because other code may rely on it.
from scipy.misc import imresize
import numpy as np
import cv2
from datetime import datetime
import re
import math
import subprocess
import glob
from time import sleep
from rpc import CortexRPC

# Name of the active module; overwritten by load_opt() from the parsed options
# and used by read_sequence() to locate sequence folders.
module_name = 'pix2pix'


def clamp(n, a, b):
    """Return n limited to the closed interval [a, b]."""
    return max(a, min(n, b))


def lerp(n, a, b):
    """Linearly interpolate between a (n == 0) and b (n == 1)."""
    return (b - a) * n + a


def load_opt():
    """Parse the test and dataset options and force single-sample test settings.

    Side effect: sets the module-level ``module_name`` from ``opt.module_name``.

    Returns:
        (opt, data_opt, data_opt_parser) where ``data_opt_parser`` is the
        DatasetOptions instance used to parse ``opt.unknown``.
    """
    global module_name

    opt_parser = TestOptions()
    opt = opt_parser.parse()

    data_opt_parser = DatasetOptions()
    data_opt = data_opt_parser.parse(opt.unknown)

    module_name = opt.module_name

    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_opt.tag = get_tag(opt, data_opt)
    opt.render_dir = opt.results_dir + opt.name + "/" + data_opt.tag + "/"

    return opt, data_opt, data_opt_parser


def get_tag(opt, data_opt):
    """Return the render tag, generating and storing one if none was given.

    When ``data_opt.tag`` is empty a tag of the form ``<name>_live_<YYYYmmddHHMM>``
    is built from ``opt.name`` and the current local time and written back to
    ``data_opt.tag``.
    """
    if data_opt.tag == '':
        d = datetime.now()
        data_opt.tag = "{}_{}_{}".format(
            opt.name, 'live', d.strftime('%Y%m%d%H%M')
        )
    return data_opt.tag


def create_render_dir(opt):
    """Create ``opt.render_dir`` from scratch, deleting any existing directory."""
    print("create render_dir: {}".format(opt.render_dir))
    if os.path.exists(opt.render_dir):
        rmtree(opt.render_dir)
    mkdirs(opt.render_dir)


def load_first_frame(opt, data_opt, i=0):
    """Write the starting frame into the render directory.

    With ``data_opt.just_copy`` the start image is copied verbatim; otherwise it
    is loaded, run through process_image() and written with OpenCV.

    Returns:
        (A_offset, A_im, A_dir):
          A_offset -- first integer found in the start image's file name, or 0.
          A_im     -- the start image as an RGB array, or None when just copying.
          A_dir    -- the start image path with that number replaced by
                      ``{:05d}``, or None when no non-zero offset was found.
    """
    start_img_path = os.path.join(opt.render_dir, "frame_{:05}.png".format(i))

    # Defaults so every path returns bound values (the original raised
    # NameError on the just_copy path and when no number was found).
    A_offset = 0
    A_im = None
    A_dir = None

    if data_opt.just_copy:
        copyfile(opt.start_img, start_img_path)
        return A_offset, A_im, A_dir

    print("preload {}".format(opt.start_img))
    A_img = Image.open(opt.start_img).convert('RGB')
    A_im = np.asarray(A_img)
    A = process_image(opt, data_opt, A_im)
    cv2.imwrite(start_img_path, A)

    numz = re.findall(r'\d+', os.path.basename(opt.start_img))
    if numz:
        A_offset = int(numz[0])

    if A_offset:
        print(">> starting offset: {}".format(A_offset))
        A_dir = opt.start_img.replace(numz[0], "{:05d}")
        print(A_dir)
    else:
        print("Sequence not found")

    return A_offset, A_im, A_dir


def process_image(opt, data_opt, im):
    """Apply the enabled OpenCV filters to an image and blend with the source.

    The channel order of ``im`` is reversed first (callers in this file pass
    RGB arrays; the OpenCV conversions below use BGR constants). Filters are
    applied in a fixed order -- CLAHE, posterize, grayscale, blur, canny --
    each only when its ``data_opt`` flag is exactly True.

    Args:
        opt: unused here; kept for interface compatibility.
        data_opt: options object providing the filter flags and parameters.
        im: image array indexed as [row, column, channel].

    Returns:
        The filtered image if no filter ran or ``data_opt.process_frac == 0``;
        otherwise a uint8 blend of ``process_frac`` * source and
        ``(1 - process_frac)`` * filtered image.
    """
    img = im[:, :, ::-1].copy()
    processed = False

    if data_opt.clahe is True:
        processed = True
        # Equalise only the L channel in LAB space.
        lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
        l, a, b = cv2.split(lab)
        clahe = cv2.createCLAHE(clipLimit=data_opt.clip_limit, tileGridSize=(8, 8))
        l = clahe.apply(l)
        limg = cv2.merge((l, a, b))
        img = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)

    if data_opt.posterize is True:
        processed = True
        img = cv2.pyrMeanShiftFiltering(img, data_opt.spatial_window, data_opt.color_window)

    if data_opt.grayscale is True:
        processed = True
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    if data_opt.blur is True:
        processed = True
        img = cv2.GaussianBlur(img, (data_opt.blur_radius, data_opt.blur_radius), data_opt.blur_sigma)

    if data_opt.canny is True:
        processed = True
        img = cv2.Canny(img, data_opt.canny_lo, data_opt.canny_hi)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    if processed is False or data_opt.process_frac == 0:
        return img

    # The grayscale path (without canny) leaves a 2-D image; expand it back to
    # three channels so it can be blended with the 3-channel source.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    src_img = im[:, :, ::-1].copy()
    frac_a = data_opt.process_frac
    frac_b = 1.0 - frac_a
    array_a = np.multiply(src_img.astype('float64'), frac_a)
    array_b = np.multiply(img.astype('float64'), frac_b)
    return np.add(array_a, array_b).astype('uint8')


def list_checkpoints(payload):
    """Return sorted names of sub-directories of ./checkpoints/<payload>/ that
    contain a ``latest_net_G.pth`` file."""
    print("> list checkpoints")
    pattern = './checkpoints/' + payload + '/*/latest_net_G.pth'
    # Take the containing directory's name instead of a fixed split index so
    # the result does not depend on how many components the payload has.
    return sorted(os.path.basename(os.path.dirname(f)) for f in glob.glob(pattern))


def list_epochs(path):
    """Return sorted epoch labels found in ./checkpoints/<path>/.

    An epoch label is the file name of each ``*_net_G.pth`` file with the
    ``_net_G.pth`` suffix removed. Returns the string "not found" when the
    directory does not exist.
    """
    print("> list epochs for {}".format(path))
    if not os.path.exists(os.path.join('./checkpoints/', path)):
        return "not found"
    pattern = './checkpoints/' + path + '/*_net_G.pth'
    return sorted(os.path.basename(f).replace('_net_G.pth', '') for f in glob.glob(pattern))


def list_sequences(module):
    """List the sequence folders under ./sequences/<module>/.

    Returns:
        A list of ``{'name': folder, 'count': number_of_files}`` dicts sorted
        by folder name; an empty list when the module folder does not exist.
    """
    print("> list sequences")
    base = os.path.join('./sequences/', module)
    if not os.path.isdir(base):
        return []

    results = []
    for name in sorted(os.listdir(base)):
        folder = os.path.join(base, name)
        if not os.path.isdir(folder):
            continue
        count = sum(
            1 for entry in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, entry))
        )
        results.append({
            'name': name,
            'count': count,
        })
    return results


def read_sequence(path):
    """Return the sorted ``*.png`` paths in ./sequences/<module_name>/<path>/."""
    print("> read sequence {}".format(path))
    return sorted(glob.glob(os.path.join('./sequences/', module_name, path, '*.png')))


class Listener():
    """Glue between the CortexRPC client and the live rendering loop.

    Holds the parsed options and the model, and exposes get/set/command/ready
    callbacks that are handed to CortexRPC in connect().
    """

    def __init__(self):
        opt, data_opt, data_opt_parser = load_opt()
        self.opt = opt
        self.data_opt = data_opt
        # The underlying parser object, reused by _set_fn to coerce values.
        self.data_opt_parser = data_opt_parser.parser
        self.model = create_model(opt)
        # Ask the render loop to (re)load the checkpoint on its first frame.
        self.data_opt.load_checkpoint = True
        self.working = False

    def _set_fn(self, key, value):
        """Set ``data_opt.<key>`` to ``value`` if that option exists.

        'True'/'False' (by string value) become booleans; anything else is run
        through the dataset option parser so it gets that option's type.
        Failures are printed and otherwise ignored.
        """
        if not hasattr(self.data_opt, key):
            return
        try:
            if str(value) == 'True':
                setattr(self.data_opt, key, True)
                print('set {} {}: {}'.format('bool', key, True))
            elif str(value) == 'False':
                setattr(self.data_opt, key, False)
                print('set {} {}: {}'.format('bool', key, False))
            else:
                new_opt, misc = self.data_opt_parser.parse_known_args([
                    '--' + key.replace('_', '-'), str(value)
                ])
                new_value = getattr(new_opt, key)
                setattr(self.data_opt, key, new_value)
                print('set {} {}: {}'.format(type(new_value), key, new_value))
        except (Exception, SystemExit) as e:
            # argparse reports invalid values by raising SystemExit, which
            # `except Exception` alone would let through and end the process.
            print('error {} - cant set value {}: {}'.format(e, key, value))

    def _get_fn(self):
        """Return all dataset options as a dict."""
        return vars(self.data_opt)

    def _cmd_fn(self, cmd, payload):
        """Dispatch an RPC command and return its result ('ok' by default)."""
        print("got command {}".format(cmd))
        if cmd == 'list_checkpoints':
            return list_checkpoints(payload)
        if cmd == 'list_epochs':
            return list_epochs(payload)
        if cmd == 'list_sequences':
            return list_sequences(payload)
        if cmd == 'load_epoch':
            # Payload format: "<checkpoint name>:<epoch>".
            name, epoch = payload.split(':')
            print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
            self.data_opt.checkpoint_name = name
            self.data_opt.epoch = epoch
            self.data_opt.load_checkpoint = True
            return 'ok'
        # NOTE(review): this existence check does not include module_name,
        # whereas read_sequence() joins ./sequences/<module_name>/<path> --
        # confirm which form of payload the client sends.
        if cmd == 'load_sequence' and os.path.exists('./sequences/' + payload):
            print('load sequence: {}'.format(payload))
            self.data_opt.sequence_name = payload
            self.data_opt.load_sequence = True
        if cmd == 'seek':
            # The render loop compares seek_to with integers, so coerce here.
            try:
                self.data_opt.seek_to = int(payload)
            except (TypeError, ValueError):
                print("error - cant seek to {}".format(payload))
        if cmd == 'get_status':
            return {
                'processing': self.data_opt.processing,
                'checkpoint': self.data_opt.checkpoint_name,
                'epoch': self.data_opt.epoch,
                'sequence': self.data_opt.sequence_name,
            }
        if cmd == 'play' and self.data_opt.processing is False:
            self.data_opt.pause = False
            # Runs the render loop to completion before returning.
            process_live_input(self.opt, self.data_opt, self.rpc_client, self.model)
        if cmd == 'pause' and self.data_opt.processing is True:
            self.data_opt.pause = True
            return 'paused'
        if cmd == 'exit':
            print("Exiting now...!")
            sys.exit(0)
        return 'ok'

    def _ready_fn(self, rpc_client):
        """Store the RPC client and start rendering once it reports ready."""
        print("Ready!")
        self.rpc_client = rpc_client
        process_live_input(self.opt, self.data_opt, rpc_client, self.model)

    def connect(self):
        """Create the CortexRPC client wired to this listener's callbacks."""
        self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)


def process_live_input(opt, data_opt, rpc_client, model):
    """Run the generate / blend / write loop until exhausted, limited or paused.

    For every item of the recursive data loader the model output (``fake_B``)
    is optionally blended with the previous output and with the current frame
    of the selected image sequence, passed through process_image() and written
    to ``opt.render_dir`` as the next ``frame_NNNNN.png``. Options on
    ``data_opt`` are re-read every iteration, so they can be changed while the
    loop runs (checkpoint, sequence, seek position, pause, blend fractions).

    ``data_opt.processing`` is True while the loop runs and is always reset,
    with a 'processing' False status sent, when the function leaves the loop.
    """
    print(">>> Process live input")
    if data_opt.processing:
        # A run is already in progress; do not start a second one.
        print("Already processing...")
        return

    data_opt.processing = True
    try:
        data_loader = CreateRecursiveDataLoader(opt)
        data_loader.load_data()

        create_render_dir(opt)

        sequence = read_sequence(data_opt.sequence_name)
        print("Got sequence {}, {} images".format(data_opt.sequence_name, len(sequence)))
        if len(sequence) == 0:
            print("Got empty sequence...")
            return
        print("First image: {}".format(sequence[0]))

        rpc_client.send_status('processing', True)

        # Seed the render directory with the first sequence image as frame 0.
        start_img_path = os.path.join(opt.render_dir, "frame_{:05d}.png".format(0))
        copyfile(sequence[0], start_img_path)

        last_im = None
        A_img = None  # most recently loaded sequence image (PIL), if any
        print("generating...")

        # Index 0 was used as the start frame, so playback starts at 1.
        sequence_i = 1

        for i, data in enumerate(data_loader):
            if i >= opt.how_many:
                break

            if data_opt.load_checkpoint is True:
                model.save_dir = os.path.join(opt.checkpoints_dir, opt.module_name, data_opt.checkpoint_name)
                model.load_network(model.netG, 'G', data_opt.epoch)
                data_opt.load_checkpoint = False

            if data_opt.load_sequence is True:
                data_opt.load_sequence = False
                new_sequence = read_sequence(data_opt.sequence_name)
                if len(new_sequence) != 0:
                    print("Got sequence {}, {} images, first: {}".format(
                        data_opt.sequence_name, len(new_sequence), new_sequence[0]))
                    sequence = new_sequence
                    sequence_i = 1

            # seek_to == 1 means "no seek pending"; anything else is consumed
            # once and reset.
            if data_opt.seek_to != 1:
                if data_opt.seek_to > 0 and data_opt.seek_to < len(sequence):
                    sequence_i = data_opt.seek_to
                data_opt.seek_to = 1

            model.set_input(data)
            model.test()
            visuals = model.get_current_visuals()
            model.get_image_paths()

            if (i % 100) == 0:
                print('%04d: process image...' % (i))

            im = visuals['fake_B']

            last_path = opt.render_dir + "frame_{:05d}.png".format(i)
            tmp_path = opt.render_dir + "frame_{:05d}_tmp.png".format(i+1)
            next_path = opt.render_dir + "frame_{:05d}.png".format(i+1)
            current_path = opt.render_dir + "ren_{:05d}.png".format(i+1)

            meta = {
                'i': i,
                'sequence_i': sequence_i,
                'sequence_len': len(sequence)
            }

            # Evaluate once per frame so both uses below agree.
            use_sequence = bool(data_opt.sequence and len(sequence))

            if use_sequence:
                # Clamp so a one-image sequence does not index past the end.
                sequence_path = sequence[min(sequence_i, len(sequence) - 1)]
                if sequence_i >= len(sequence)-1:
                    print('(((( sequence looped ))))')
                    sequence_i = 1
                else:
                    sequence_i += 1

            if data_opt.send_image == 'b':
                image_pil = Image.fromarray(im, mode='RGB')
                rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, image_pil)

            if data_opt.store_a is not True:
                os.remove(last_path)

            if data_opt.store_b is True:
                # Write to a temp name, then rename to the final name.
                image_pil = Image.fromarray(im, mode='RGB')
                image_pil.save(tmp_path)
                os.rename(tmp_path, current_path)

            if data_opt.recursive and last_im is not None:
                if use_sequence:
                    # Three-way blend: previous frame (a), sequence image (b),
                    # current model output (c).
                    A_img = Image.open(sequence_path).convert('RGB')
                    A_im = np.asarray(A_img)
                    frac_a = data_opt.recursive_frac
                    frac_b = data_opt.sequence_frac
                    frac_sum = frac_a + frac_b
                    if frac_sum > 1.0:
                        frac_a = frac_a / frac_sum
                        frac_b = frac_b / frac_sum
                    if data_opt.transition:
                        # Sine wave over transition_period frames, mapped to
                        # [transition_min, transition_max].
                        t = lerp(
                            math.sin(i / data_opt.transition_period * math.pi * 2.0) / 2.0 + 0.5,
                            data_opt.transition_min,
                            data_opt.transition_max
                        )
                        frac_a *= 1.0 - t
                        frac_b *= 1.0 - t
                    frac_c = 1.0 - frac_a - frac_b
                    array_a = np.multiply(last_im.astype('float64'), frac_a)
                    array_b = np.multiply(A_im.astype('float64'), frac_b)
                    array_c = np.multiply(im.astype('float64'), frac_c)
                    array_ab = np.add(array_a, array_b)
                    array_abc = np.add(array_ab, array_c)
                    next_im = array_abc.astype('uint8')
                else:
                    # Two-way blend of previous frame and current output.
                    frac_a = data_opt.recursive_frac
                    frac_b = 1.0 - frac_a
                    array_a = np.multiply(last_im.astype('float64'), frac_a)
                    array_b = np.multiply(im.astype('float64'), frac_b)
                    next_im = np.add(array_a, array_b).astype('uint8')

                if data_opt.recurse_roll != 0:
                    last_im = np.roll(im, data_opt.recurse_roll, axis=data_opt.recurse_roll_axis)
                else:
                    last_im = next_im.copy().astype('uint8')
            else:
                last_im = im.copy().astype('uint8')
                next_im = im

            next_img = process_image(opt, data_opt, next_im)

            # A_img only exists once a sequence blend has happened; skip the
            # send until then instead of raising NameError.
            if data_opt.send_image == 'sequence' and A_img is not None:
                rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, A_img)
            if data_opt.send_image == 'recursive':
                pil_im = Image.fromarray(next_im)
                rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im)
            if data_opt.send_image == 'a':
                # process_image() can return a single-channel image.
                if next_img.ndim == 2:
                    rgb_im = cv2.cvtColor(next_img, cv2.COLOR_GRAY2RGB)
                else:
                    rgb_im = cv2.cvtColor(next_img, cv2.COLOR_BGR2RGB)
                pil_im = Image.fromarray(rgb_im)
                rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im)

            # Write to a temp name, then rename to the final frame name.
            cv2.imwrite(tmp_path, next_img)
            os.rename(tmp_path, next_path)
            print("created {}".format(next_path))

            if data_opt.pause:
                data_opt.pause = False
                break
    finally:
        # Always clear the flag so a failed run cannot block later 'play's.
        data_opt.processing = False
        rpc_client.send_status('processing', False)


if __name__ == '__main__':
    listener = Listener()
    listener.connect()