"""Live BigGAN interpolation server.

Loads a TF-Hub image generator, builds a placeholder-driven latent/label
interpolation graph, and streams generated frames to a CortexRPC peer.
All interpolation state lives host-side in ``InterpolatorParam`` values and
is fed into the graph each step via ``feed_dict``.
"""
import os
import sys
import glob
import h5py
import numpy as np
import params
import json
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
import random
from scipy.stats import truncnorm
from PIL import Image

import visualize as vs

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Project-local RPC client lives in a sibling repository checkout.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '../../live-cortex/rpc/'))
from rpc import CortexRPC
from params import Params

FPS = 25

# NOTE: rebinds the `params` module name to a Params instance (kept for
# compatibility with the rest of this file, which only uses the instance).
params = Params(os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             'params_dense-512.json'))

# --------------------------
# Make directories.
# --------------------------
tag = "test"
OUTPUT_DIR = os.path.join('output', tag)
# exist_ok avoids the check-then-create race of the previous exists() guard.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --------------------------
# Load Graph.
# --------------------------
print("Loading module...")
generator = hub.Module(str(params.generator_path))
print("Loaded!")

# Some hub modules expose a named 'generator' signature; fall back to 'default'.
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'

input_info = generator.get_input_info_dict(gen_signature)
BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]    # latent vector length
N_CLASS = input_info['y'].get_shape().as_list()[1]  # class-conditioning length


# --------------------------
# Initializers
# --------------------------
def label_sampler(num_classes=2, shape=(BATCH_SIZE, N_CLASS,)):
    """Sample a batch of soft class-label vectors.

    For each row, a random number of random class slots get a random weight,
    then the row is normalized to sum to 1.

    Args:
        num_classes: unused; kept for interface compatibility with callers
            that pass it positionally.
        shape: (batch, n_classes) output shape.

    Returns:
        np.ndarray of the given shape, rows summing to 1.
    """
    label = np.zeros(shape)
    for i in range(shape[0]):
        for _ in range(random.randint(1, shape[1])):
            j = random.randint(0, shape[1] - 1)
            label[i, j] = random.random()
        label[i] /= label[i].sum()
    return label


def truncated_z_sample(shape=(BATCH_SIZE, Z_DIM,), truncation=1.0):
    """Sample latent z from a truncated normal (clipped at +/-2 sigma)."""
    values = truncnorm.rvs(-2, 2, size=shape)
    return truncation * values


def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
    """Sample latent z from a standard normal."""
    return np.random.normal(size=shape)


# --------------------------
# More complex ops
# --------------------------
def sin(opts, key, shape):
    """Build a sinusoidal orbit term around a lerped noise center.

    Registers '<key>_scale' and '<key>_speed' params in ``opts`` and returns
    the tensor ``sin(time + noise) * scale``.

    NOTE(review): ``speed`` is registered but not consumed by the graph
    (time is not multiplied by it) — presumably driven host-side; confirm.
    """
    noise = lerp(opts, key + '_noise', shape)
    scale = InterpolatorParam(name=key + '_scale', value=0.1)
    speed = InterpolatorParam(name=key + '_speed', value=1.0)
    t = opts['time'].variable
    out = tf.sin(t + noise) * scale.variable
    opts[key + '_scale'] = scale
    opts[key + '_speed'] = speed
    return out


def lerp(opts, key, shape):
    """Build a linear interpolation ``a * (1 - n) + b * n`` between two params.

    Registers '<key>_a', '<key>_b', '<key>_n' and '<key>_speed' in ``opts``.
    ``n`` is the host-controlled blend factor; ``speed`` is registered for the
    host to advance ``n`` with (not used inside the graph).
    """
    a = InterpolatorParam(name=key + '_a', shape=shape)
    b = InterpolatorParam(name=key + '_b', shape=shape)
    n = InterpolatorParam(name=key + '_n', value=0.0)
    speed = InterpolatorParam(name=key + '_speed', value=0.1)
    out = a.variable * (1 - n.variable) + b.variable * n.variable
    opts[key + '_a'] = a
    opts[key + '_b'] = b
    opts[key + '_n'] = n
    opts[key + '_speed'] = speed
    return out


# --------------------------
# Placeholder params
# --------------------------
class InterpolatorParam:
    """A host-side value paired with a graph placeholder.

    The current value is kept in ``self.value`` and fed into ``self.variable``
    (a tf.placeholder) each step via feed_dict.
    """

    def __init__(self, name, dtype=tf.float32, shape=(), value=None, type="noise"):
        # `type` shadows the builtin but is kept for caller compatibility.
        self.scalar = shape == ()
        self.shape = shape
        self.type = type
        self.value = value if value is not None else np.zeros(shape)
        self.variable = tf.placeholder(dtype=dtype, shape=shape)

    def assign(self, value):
        """Set the host-side value fed to the placeholder on the next step."""
        self.value = value

    def randomize(self, num_classes):
        """Replace the value with a fresh random sample matching self.type."""
        if self.type == 'noise':
            val = truncated_z_sample(shape=self.shape)
        elif self.type == 'label':
            val = label_sampler(shape=self.shape)
        else:
            # Previously this fell through and raised an opaque NameError.
            raise ValueError("unknown param type: {}".format(self.type))
        self.assign(val)


# --------------------------
# Interpolator graph
# --------------------------
class Interpolator:
    """Owns the interpolation params and the generator output tensor."""

    def __init__(self):
        self.opts = {
            'time': InterpolatorParam(name='t', value=time.time()),
            'truncation': InterpolatorParam(name='truncation', value=1.0),
        }

    def build(self):
        """Wire latent lerp + sinusoidal orbit + label lerp into the generator."""
        lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM])
        sin_z = sin(self.opts, 'orbit', [BATCH_SIZE, Z_DIM])
        lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS])
        gen_in = {}
        gen_in['truncation'] = self.opts['truncation'].variable
        gen_in['z'] = lerp_z + sin_z
        gen_in['y'] = lerp_label
        gen_img = generator(gen_in, signature=gen_signature)
        # Convert generated image to channels_first.
        self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

    def get_feed_dict(self):
        """Map every param's placeholder to its current host-side value."""
        opt = {}
        for key, param in self.opts.items():
            opt[param.variable] = param.value
        return opt

    def get_state(self):
        """Return a JSON-serializable dict of the scalar params.

        Non-scalar (array) params are deliberately omitted so the result can
        be json.dumps'd by the RPC layer; a scalar param that somehow holds
        an ndarray is converted and flagged on stderr.
        """
        opt = {}
        for key, param in self.opts.items():
            if param.scalar:
                if type(param.value) is np.ndarray:
                    sys.stderr.write('{} is ndarray'.format(key))
                    opt[key] = param.value.tolist()
                else:
                    opt[key] = param.value
        # Bug fix: the return was missing, so on_get() serialized None.
        return opt

    def set_value(self, key, value):
        """Assign a new host-side value to the named param."""
        self.opts[key].assign(value)

    def on_step(self, i, sess):
        """Run one generation step and return the image batch (NCHW)."""
        gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
        return gen_images

    def run(self, cmd, payload):
        # do things like create a new B and interpolate to it
        pass


# --------------------------
# RPC Listener
# --------------------------
class Listener:
    """Bridges CortexRPC callbacks to the Interpolator and runs the frame loop."""

    def __init__(self):
        self.interpolator = Interpolator()
        self.interpolator.build()

    def connect(self):
        """Open the RPC connection; on_ready starts the generation loop."""
        self.rpc_client = CortexRPC(self.on_get, self.on_set,
                                    self.on_ready, self.on_cmd)

    def on_set(self, key, value):
        self.interpolator.set_value(key, value)

    def on_get(self):
        state = self.interpolator.get_state()
        sys.stderr.write(json.dumps(state))
        sys.stderr.flush()
        return state

    def on_cmd(self, cmd, payload):
        print("got command {}".format(cmd))
        self.interpolator.run(cmd, payload)

    def on_ready(self, rpc_client):
        """Initialize the TF session and stream frames until on_step yields None."""
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.tables_initializer())
        print("Ready!")
        self.rpc_client = rpc_client
        self.rpc_client.send_status('processing', True)
        for i in range(99999):
            if (i % 100) == 0:
                print("Step {}".format(i))
            gen_time = time.time()
            # Drive the graph's time input from the wall clock.
            self.interpolator.opts['time'].assign(gen_time)
            gen_images = self.interpolator.on_step(i, self.sess)
            if gen_images is None:
                print("Exiting...")
                break
            if (i % 100) == 0:
                print("Generation time: {:.1f}s".format(time.time() - gen_time))
            out_img = vs.data2pil(gen_images[0])
            if out_img is not None:
                #if out_img.resize_before_sending:
                img_to_send = out_img.resize((256, 256), Image.BICUBIC)
                meta = {
                    'i': i,
                    'sequence_i': i,
                    'skip_i': 0,
                    'sequence_len': 99999,
                }
                self.rpc_client.send_pil_image(
                    "frame_{:05d}.png".format(i + 1), meta, img_to_send, 'jpg')
        self.rpc_client.send_status('processing', False)
        self.sess.close()


if __name__ == '__main__':
    listener = Listener()
    listener.connect()

# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
#                            shape=[BATCH_SIZE,] + ENC_SHAPE)
# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))