"""Drive a TF-Hub GAN generator over RPC, streaming interpolated frames."""
import os
import sys
import time
import random

import numpy as np
from PIL import Image
import tensorflow as tf
import tensorflow_hub as hub

import visualize as vs
from params import Params

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '../../live-cortex/rpc/'))
from rpc import CortexRPC

params = Params('params_dense.json')

# --------------------------
# Make directories.
# --------------------------
tag = "test"
OUTPUT_DIR = os.path.join('output', tag)
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

# --------------------------
# Load Graph.
# --------------------------
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

generator = hub.Module(str(params.generator_path))
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'
input_info = generator.get_input_info_dict(gen_signature)

COND_GAN = True
BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]


# --------------------------
# Initializers
# --------------------------
def label_sampler(shape=[BATCH_SIZE, N_CLASS]):
    """Sample a sparse, normalized class-probability vector per batch element."""
    label = np.zeros(shape)
    for i in range(shape[0]):
        for _ in range(random.randint(1, shape[1])):
            j = random.randint(0, shape[1] - 1)
            label[i, j] = random.random()
        label[i] /= label[i].sum()
    return label


# --------------------------
# More complex ops
# --------------------------
def sin(opts, key, shape):
    """Sinusoid over wall-clock time with a lerp-able noise term and scale."""
    noise = lerp(opts, key + '_noise', shape)
    scale = InterpolatorParam(name=key + '_scale')
    t = opts['global']['time'].variable
    out = tf.sin(t + noise) * scale.variable
    opts[key] = {
        'scale': scale,
    }
    return out


def lerp(opts, key, shape):
    """Linear interpolation between endpoints a and b, driven by scalar n."""
    a = InterpolatorParam(name=key + '_a', shape=shape)
    b = InterpolatorParam(name=key + '_b', shape=shape)
    n = InterpolatorParam(name=key + '_n')
    out = a.variable * (1 - n.variable) + b.variable * n.variable
    opts[key] = {
        'a': a,
        'b': b,
        'n': n,
    }
    return out


class InterpolatorParam:
    def __init__(self, name, dtype=tf.float32, shape=(), value=None, type="noise"):
        self.scalar = shape == ()
        self.shape = shape
        self.type = type
        self.value = value if value is not None else np.zeros(shape)
        self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape)

    def assign(self, value):
        """Return an op that assigns `value` to the underlying variable."""
        self.value = value
        return self.variable.assign(value)

    def randomize(self):
        """Return an assign op that resamples this param according to its type."""
        if self.type == 'noise':
            val = np.random.normal(size=self.shape)
        elif self.type == 'label':
            val = label_sampler(shape=self.shape)
        else:
            val = self.value
        return self.assign(val)


class Interpolator:
    def __init__(self):
        self.opts = {
            'global': {
                'time': InterpolatorParam(name='t', value=time.time())
            },
        }

    def build(self):
        lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM])
        sin_z = sin(self.opts, 'sin_z', [BATCH_SIZE, Z_DIM])
        lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS])
        self.opts['global']['truncation'] = InterpolatorParam('truncation', value=1.0)

        gen_in = {
            'truncation': self.opts['global']['truncation'].variable,
            'z': lerp_z + sin_z,
            'y': lerp_label,
        }
        gen_img = generator(gen_in, signature=gen_signature)
        # Convert generated image to channels_first.
        self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

        # Initialize all variables (generator weights included) now that the
        # graph is complete.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        # Seed every interpolation parameter: scalars keep their initial value,
        # vector-valued params (noise / labels) get a fresh random sample.
        for lookup in self.opts.values():
            for param in lookup.values():
                if param.scalar:
                    sess.run(param.assign(param.value))
                else:
                    sess.run(param.randomize())

    def get_state(self):
        """Return the current value of every scalar parameter, keyed 'group/name'."""
        opt = {}
        for group, lookup in self.opts.items():
            for key, param in lookup.items():
                if param.scalar:
                    opt['{}/{}'.format(group, key)] = param.value
        return opt

    def set_value(self, key, value):
        group, name = key.split('/')
        sess.run(self.opts[group][name].assign(value))

    def on_step(self, i):
        gen_time = time.time()
        sess.run(self.opts['global']['time'].assign(gen_time))
        gen_images = sess.run(self.gen_img)
        print("Generation time: {:.1f}s".format(time.time() - gen_time))
        return gen_images

    def run(self, cmd, payload):
        # do things like create a new B and interpolate to it
        pass


class Listener:
    def __init__(self):
        self.interpolator = Interpolator()
        self.interpolator.build()

    def connect(self):
        self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)

    def on_set(self, key, value):
        self.interpolator.set_value(key, value)

    def on_get(self):
        return self.interpolator.get_state()

    def on_cmd(self, cmd, payload):
        print("got command {}".format(cmd))
        self.interpolator.run(cmd, payload)

    def on_ready(self, rpc_client):
        print("Ready!")
        self.rpc_client = rpc_client
        self.rpc_client.send_status('processing', True)

        for i in range(99999):
            gen_images = self.interpolator.on_step(i)
            if gen_images is None:
                break
            out_img = vs.data2pil(gen_images[0])
            if out_img is None:
                continue
            # Downscale before sending to keep RPC payloads small.
            out_img = out_img.resize((256, 256), Image.BICUBIC)
            # `meta` and the output format are placeholders; adjust to whatever
            # the CortexRPC endpoint expects.
            meta = {'frame': i}
            self.rpc_client.send_pil_image("frame_{:05d}.png".format(i + 1),
                                           meta, out_img, 'png')

        self.rpc_client.send_status('processing', False)
        sess.close()


if __name__ == '__main__':
    listener = Listener()
    listener.connect()

# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
#                            shape=[BATCH_SIZE,] + ENC_SHAPE)
# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))