import os
# Silence TensorFlow's C++ logging; must be set before tensorflow is imported below.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sys
import glob
import h5py
import numpy as np
import params
import json
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
import random
from scipy.stats import truncnorm
from PIL import Image
import visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# The RPC client lives outside this package; make it importable.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
from rpc import CortexRPC
from params import Params

FPS = 25

# NOTE: this rebinds the name `params`, shadowing the `params` module imported above.
params = Params(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'params_dense-512.json'))

# --------------------------
# Make directories.
# --------------------------

tag = "test"
OUTPUT_DIR = os.path.join('output', tag)
# exist_ok=True replaces an exists()/makedirs() pair, which could race with
# another process creating the same directory.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --------------------------
# Load Graph.
# --------------------------

print("Loading module...")
generator = hub.Module(str(params.generator_path))
print("Loaded!")

# Prefer the module's 'generator' signature; fall back to 'default' if absent.
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'

input_info = generator.get_input_info_dict(gen_signature)

BATCH_SIZE = 1
# Latent / class sizes are read from dimension 1 of the signature's 'z' and 'y' inputs.
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]

# --------------------------
# Utils
# --------------------------


def clamp(n, a=0, b=1):
    """Return ``n`` limited to the closed interval [a, b].

    Returns ``a`` if n < a, ``b`` if n > b, otherwise ``n`` unchanged.
    """
    if n < a:
        return a
    if n > b:
        return b
    return n

# --------------------------
# Initializers
# --------------------------


def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)):
    """Sample a random soft class-label array.

    For each row, a random count (1..shape[1]) of class indices is drawn with
    replacement. Each drawn index is assigned a weight from ``random.random()``;
    a repeated index keeps its last draw. The row is then normalized to sum to 1.

    Args:
        num_classes: currently has no effect on the result.
            NOTE(review): callers pass the 'num_classes' option here; decide
            whether it should bound the number of mixed classes.
        shape: (rows, classes) shape of the returned array.

    Returns:
        float64 ``np.ndarray`` of ``shape``; each row sums to 1.
    """
    label = np.zeros(shape)
    for i in range(shape[0]):
        for _ in range(random.randint(1, shape[1])):
            j = random.randint(0, shape[1] - 1)
            label[i, j] = random.random()
        total = label[i].sum()
        if total > 0:
            label[i] /= total
        else:
            # Every drawn weight was exactly 0.0: dividing would yield NaNs,
            # so fall back to a uniform distribution over all classes.
            label[i] = 1.0 / shape[1]
    return label


def truncated_z_sample(shape=(BATCH_SIZE, Z_DIM,), truncation=1.0):
    """Sample a standard normal truncated to [-2, 2], scaled by ``truncation``."""
    values = truncnorm.rvs(-2, 2, size=shape)
    return truncation * values


def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
    """Sample an untruncated standard-normal array of ``shape``."""
    return np.random.normal(size=shape)
# --------------------------
# More complex ops
# --------------------------


class SinParam:
    """Orbiting offset tensor: ``sin(orbit_time + noise) * orbit_radius``.

    Creates a LerpParam ``<key>_noise`` (the per-element phase) and the scalar
    InterpolatorParams ``<key>_radius``, ``<key>_speed`` and ``<key>_time``.
    It then registers itself in ``interpolator.sin_params[key]``.

    Args:
        key: name prefix for the params this creates.
        shape: shape of the phase-noise tensor.
        type: sampler used when the phase noise is re-randomized
            ('noise' or 'label'); defaults to 'noise'.
    """

    def __init__(self, key, shape, type="noise"):
        noise = LerpParam(key + '_noise', shape, type=type)
        orbit_radius = InterpolatorParam(name=key + '_radius', value=0.1)
        orbit_speed = InterpolatorParam(name=key + '_speed', value=0.1)
        orbit_time = InterpolatorParam(name=key + '_time', value=0.0)
        output = tf.math.sin(orbit_time.variable + noise.output) * orbit_radius.variable
        interpolator.sin_params[key] = self
        self.key = key
        self.noise = noise
        self.orbit_radius = orbit_radius
        self.orbit_speed = orbit_speed
        self.orbit_time = orbit_time
        self.output = output

    def update(self, dt):
        """Advance the orbit phase by ``orbit_speed * dt``."""
        self.orbit_time.value += self.orbit_speed.value * dt


class LerpParam:
    """Linear blend ``a * (1 - n) + b * n`` between two placeholder endpoints.

    ``switch()`` re-samples the endpoint farther from the current ``n`` and
    starts moving ``n`` toward it. ``update(dt)`` moves ``n`` by
    ``speed * dt`` per call until it reaches 0 or 1.
    Registers itself in ``interpolator.lerp_params[key]``.
    """

    def __init__(self, key, shape, type="noise"):
        a = InterpolatorParam(name=key + '_a', shape=shape, type=type)
        b = InterpolatorParam(name=key + '_b', shape=shape, type=type)
        n = InterpolatorParam(name=key + '_n', value=0.0)
        speed = InterpolatorParam(name=key + '_speed', value=0.1)
        output = a.variable * (1 - n.variable) + b.variable * n.variable
        interpolator.lerp_params[key] = self
        self.key = key
        self.a = a
        self.b = b
        self.n = n
        self.speed = speed
        self.output = output
        # -1: moving toward a, +1: moving toward b, 0: at rest.
        self.direction = 0

    def switch(self):
        """Re-sample the far endpoint and start interpolating toward it."""
        # Compare the fed value, not the InterpolatorParam object itself.
        if self.n.value > 0.5:
            self.a.randomize()
            self.direction = -1
        else:
            self.b.randomize()
            self.direction = 1

    def update(self, dt):
        """Move ``n`` toward the active endpoint; stop when it reaches 0 or 1."""
        if self.direction != 0:
            self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt)
            if self.n.value == 0 or self.n.value == 1:
                self.direction = 0

# --------------------------
# Placeholder params
# --------------------------


class InterpolatorParam:
    """A ``tf.placeholder`` plus the host-side value fed to it on every step.

    Registers itself in ``interpolator.opts[name]``. ``value`` defaults to
    ``np.zeros(shape)`` when not given.
    """

    def __init__(self, name, dtype=tf.float32, shape=(), value=None, type="noise"):
        self.name = name
        self.scalar = shape == ()
        self.shape = shape
        self.type = type
        self.value = value if value is not None else np.zeros(shape)
        self.variable = tf.placeholder(dtype=dtype, shape=shape)
        interpolator.opts[name] = self

    def assign(self, value):
        """Set the value fed to the placeholder on the next run."""
        self.value = value

    def randomize(self):
        """Replace ``value`` with a fresh sample for this param's ``type``.

        Raises:
            ValueError: if ``type`` is neither 'noise' nor 'label'.
        """
        if self.type == 'noise':
            val = truncated_z_sample(shape=self.shape,
                                     truncation=interpolator.opts['truncation'].value)
        elif self.type == 'label':
            val = label_sampler(shape=self.shape,
                                num_classes=interpolator.opts['num_classes'].value)
        else:
            raise ValueError('{}: cannot randomize param of type {!r}'.format(self.name, self.type))
        self.assign(val)

# --------------------------
# Interpolator graph
# --------------------------


class Interpolator:
    """Owns every InterpolatorParam/LerpParam/SinParam and the generator output."""

    def __init__(self):
        self.opts = {}
        self.sin_params = {}
        self.lerp_params = {}

    def build(self):
        """Create all params and apply the generator to them."""
        # InterpolatorParam registers itself in self.opts under its name, so
        # no assignment is needed (assigning with a trailing comma stored a tuple).
        InterpolatorParam(name='truncation', value=1.0)
        InterpolatorParam(name='num_classes', value=1.0)

        lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM])
        sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM])
        # Labels must be re-sampled with label_sampler, not gaussian noise.
        lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], type='label')

        gen_in = {}
        gen_in['truncation'] = 1.0  # TODO: feed self.opts['truncation'].variable instead
        gen_in['z'] = lerp_z.output + sin_z.output
        gen_in['y'] = lerp_label.output
        self.gen_img = generator(gen_in, signature=gen_signature)

        sys.stderr.write("Sin params: {}\n".format(", ".join(self.sin_params.keys())))
        sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys())))
        sys.stderr.write("Opts: {}\n".format(", ".join(self.opts.keys())))

        # NOTE: gen_img is left as returned by the generator; a channels-first
        # variant would be tf.transpose(self.gen_img, [0, 3, 1, 2]).

    def get_feed_dict(self):
        """Map every param's placeholder to its current host-side value."""
        return {param.variable: param.value for param in self.opts.values()}

    def get_state(self):
        """Return a JSON-friendly dict of the scalar options only."""
        state = {}
        for key, param in self.opts.items():
            if not param.scalar:
                continue
            if isinstance(param.value, np.ndarray):
                # e.g. the 0-d default np.zeros(()); convert for JSON.
                sys.stderr.write('{} is ndarray\n'.format(key))
                state[key] = param.value.tolist()
            else:
                state[key] = param.value
        return state

    def set_value(self, key, value):
        """Assign ``value`` to option ``key``, logging unknown keys to stderr."""
        if key in self.opts:
            self.opts[key].assign(value)
        else:
            sys.stderr.write('{} not a valid option\n'.format(key))

    def on_step(self, i, dt, sess):
        """Advance all animated params by ``dt`` and render one batch of images."""
        for param in self.sin_params.values():
            param.update(dt)
        for param in self.lerp_params.values():
            param.update(dt)
        gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
        return gen_images

    def run(self, cmd, payload):
        """Handle an RPC command; only 'switch <lerp key>' is supported."""
        if cmd == 'switch' and payload in self.lerp_params:
            self.lerp_params[payload].switch()
        else:
            sys.stderr.write('unhandled command {} {}\n'.format(cmd, payload))

# --------------------------
# RPC Listener
# --------------------------


interpolator = Interpolator()


class Listener:
    """Bridges CortexRPC callbacks to the module-level ``interpolator``."""

    def connect(self):
        self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)

    def on_set(self, key, value):
        interpolator.set_value(key, value)

    def on_get(self):
        state = interpolator.get_state()
        sys.stderr.write(json.dumps(state) + "\n")
        sys.stderr.flush()
        return state

    def on_cmd(self, cmd, payload):
        print("got command {}".format(cmd))
        interpolator.run(cmd, payload)

    def on_ready(self, rpc_client):
        """Build the graph, open a session, and stream generated frames over RPC."""
        # Apply the generator and create all placeholders before running the
        # initializers, so the initializers cover everything in the graph.
        interpolator.build()
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.tables_initializer())
        print("Ready!")
        self.rpc_client = rpc_client
        self.rpc_client.send_status('processing', True)
        dt = 1 / FPS
        num_frames = 99999
        try:
            for i in range(num_frames):
                log_step = (i % 100) == 0
                if log_step:
                    print("Step {}".format(i))
                gen_time = time.time()
                # build() registers no 'time' option, so guard the lookup.
                if 'time' in interpolator.opts:
                    interpolator.opts['time'].assign(gen_time)
                gen_images = interpolator.on_step(i, dt, self.sess)
                if gen_images is None:
                    print("Exiting...")
                    break
                if log_step:
                    print(gen_images.shape)
                    print("Generation time: {:.1f}s".format(time.time() - gen_time))
                out_img = vs.data2pil(gen_images[0])
                if out_img is not None:
                    img_to_send = out_img.resize((256, 256), Image.BICUBIC)
                    meta = {
                        'i': i,
                        'sequence_i': i,
                        'skip_i': 0,
                        'sequence_len': num_frames,
                    }
                    self.rpc_client.send_pil_image("frame_{:05d}.png".format(i + 1), meta, img_to_send, 'jpg')
        finally:
            # Always report completion and release the session, even on error.
            self.rpc_client.send_status('processing', False)
            self.sess.close()


if __name__ == '__main__':
    listener = Listener()
    listener.connect()

# NOTE: unused sketch (presumably for latent inversion via params.inv_layer);
# kept for reference.
# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
#                            shape=[BATCH_SIZE,] + ENC_SHAPE)
# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))