From ba4d18a2d20eed034ea7926d10e5f760d3809ba2 Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Wed, 18 Dec 2019 23:41:47 +0100 Subject: live --- inversion/listener.py | 7 +- inversion/live.py | 191 ++++++++++++++++++++++++++++++++++--------------- inversion/visualize.py | 3 + 3 files changed, 137 insertions(+), 64 deletions(-) diff --git a/inversion/listener.py b/inversion/listener.py index a43c33c..fc04586 100644 --- a/inversion/listener.py +++ b/inversion/listener.py @@ -4,7 +4,7 @@ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../.. from rpc import CortexRPC class Listener: - def __init__(self, opt, run_live): + def __init__(self, opt, run_live, run_cmd): self.opt = opt self.run_live = run_live def _set_fn(self, key, value): @@ -13,10 +13,7 @@ class Listener: return self.opt def _cmd_fn(self, cmd, payload): print("got command {}".format(cmd)) - if cmd == '': - pass - else: - pass + run_cmd(cmd, payload) def _ready_fn(self, rpc_client): print("Ready!") self.rpc_client = rpc_client diff --git a/inversion/live.py b/inversion/live.py index 672853a..8ad38a2 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -4,13 +4,14 @@ import glob import h5py import numpy as np import params -import PIL import tensorflow as tf import tensorflow_probability as tfp import tensorflow_hub as hub import time import visualize as vs tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/')) +from rpc import CortexRPC from listener import Listener from params import Params @@ -29,6 +30,10 @@ if not os.path.exists(OUTPUT_DIR): # -------------------------- # Load Graph. # -------------------------- +sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) +sess.run(tf.global_variables_initializer()) +sess.run(tf.tables_initializer()) + generator = hub.Module(str(params.generator_path)) gen_signature = 'generator' @@ -42,33 +47,134 @@ BATCH_SIZE = 1 Z_DIM = input_info['z'].get_shape().as_list()[1] N_CLASS = input_info['y'].get_shape().as_list()[1] -def sin(key, shape): - # speed should be computed outside of tensorflow - # (so we can recursively update t = last_t + speed) - noise, noise_a, noise_b, noise_n = lerp('sin_noise', shape) - scale = tf.get_variable(name=key + '_scale', dtype=tf.float32, shape=(1,)) - t = tf.get_variable(name=key + '_t', dtype=tf.float32, shape=(1,)) - out = tf.sin(t + noise) * scale - return out, t, scale, noise_a, noise_b, noise_n - -def lerp(key, shape): - a = tf.get_variable(name=key + '_a', dtype=tf.float32, shape=shape) - b = tf.get_variable(name=key + '_b', dtype=tf.float32, shape=shape) - n = tf.get_variable(name=key + '_n', dtype=tf.float32, shape=(1,)) - out = a * (1 - n) + b * n - return out, a, b, n - -lerp_z, z_a, z_b, z_n = lerp('latent', [BATCH_SIZE, Z_DIM]) -sin_z, sin_t, sin_scale, sin_noise_a, sin_noise_b, sin_noise_n = lerp('sin_z', [BATCH_SIZE, Z_DIM]) -lerp_label, label_a, label_b, label_n = lerp('label', [BATCH_SIZE, N_CLASS]) -gen_in = dict(params.generator_fixed_inputs) -gen_in['z'] = lerp_z + sin_z -gen_in['y'] = lerp_label -gen_img = generator(gen_in, signature=gen_signature) +def sin(opts, key, shape): + noise = lerp(opts, key + '_noise', shape) + scale = InterpolatorParam(name=key + '_scale') + time = opts['global']['time'].variable + out = tf.sin(time + noise) * scale + opts[key] = { + 'scale': scale, + } + return out + +def lerp(opts, key, shape): + a = InterpolatorParam(name=key + '_a', shape=shape) + b = InterpolatorParam(name=key + '_b', shape=shape) + n = InterpolatorParam(name=key + '_n') + out = a * (1 - n) + b * n + opts[key] = { + 'a': a, + 'b': b, + 'n': n, + } + return out + +class InterpolatorParam: + def __init__(self, name, dtype=tf.float32, shape=(), value=None): + self.scalar = shape == () + self.shape = shape + self.value = np.zeros(shape) if value == None else + self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape) + + def assign(value): + self.value = value + return self.variable.assign(value) + + def randomize(self): + return self.assign(np.random.normal(size=self.shape)) + +class Interpolator: + def __init__(self): + self.opts = {} + self.t = time.time() + + def build(self): + opts = { + 'global': { + 'time': InterpolatorParam(name='t', value=time.time()) + }, + } + lerp_z = lerp(opts, 'latent', [BATCH_SIZE, Z_DIM]) + sin_z = sin(opts, 'sin_z', [BATCH_SIZE, Z_DIM]) + lerp_label = lerp(opts, 'label', [BATCH_SIZE, N_CLASS]) + opts['threshold'] = InterpolatorParam('threshold', value=1.0) + + gen_in = {} + gen_in['threshold'] = opts['threshold'].variable + gen_in['z'] = lerp_z + sin_z + gen_in['y'] = lerp_label + gen_img = generator(gen_in, signature=gen_signature) + + # Convert generated image to channels_first. + self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) + + for group, lookup in self.opts.items(): + for key, param in group.items(): + if param.scalar: + param.assign().eval(session=sess) + else: + param.randomize().eval(session=sess) + + def get_state(self): + opt = {} + for group, lookup in self.opts.items(): + for key, param in group.items(): + if param.scalar: + opt[group][key] = param.value + return opt + + def set_value(self, key, value): + self.opts[key].assign(value).eval(session=sess) + + def on_step(i): + gen_time = time.time() + self.opts['global']['time'].assign(gen_time).eval(session=sess) + gen_images = sess.run(self.gen_img) + print("Generation time: {:.1f}s".format(time.time() - gen_time)) + return gen_images + + def run(cmd, payload): + # do things like create a new B and interpolate to it + break + +class Listener: + def __init__(self): + self.interpolator = Interpolator() + self.interpolator.build() + + def connect(self): + self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd) + + def on_set(self, key, value): + self.interpolator.set_value(key, value) + + def on_get(self): + return self.interpolator.get_state() + + def on_cmd(self, cmd, payload): + print("got command {}".format(cmd)) + self.interpolator.run(cmd, payload) + + def on_ready(self, rpc_client): + print("Ready!") + self.rpc_client = rpc_client + self.rpc_client.send_status('processing', True) + for i in range(99999): + gen_images = self.interpolator.on_step(i): + if gen_images is None: + break + out_img = vs.data2pil(gen_images[0]) + if out_img is not None: + if out_img.resize_before_sending: + out_img.resize((256, 256), Image.BICUBIC) + self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, data_opt.output_format) + self.rpc_client.send_status('processing', False) + sess.close() -# Convert generated image to channels_first. -gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) +if __name__ == '__main__': + listener = Listener() + listener.connect() # layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer # gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name) @@ -76,36 +182,3 @@ gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) # encoding = tf.get_variable(name='encoding', dtype=tf.float32, # shape=[BATCH_SIZE,] + ENC_SHAPE) # tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding)) - -IMG_SHAPE = gen_img.get_shape().as_list()[1:] - -t = time.time() - -def on_step(): - # local variables to update: - # t, sin_speed, sin_t - # variables to assign: - # z_a, z_b, z_n - # label_a, label_b, label_n - # sin_t, sin_noise, sin_scale, sin_amount - # sin_noise_a, sin_noise_b, sin_noise_n - # sess.run([ - # target.assign(image_batch) - # ]) - # sess.run(label.assign(label_batch)) - gen_time = time.time() - gen_images = sess.run(gen_img) - print("Generation time: {:.1f}s".format(time.time() - gen_time)) - # convert to png and send this back... - out_img = vs.data2img(image_batch[0]) - pass - -def run_live(): - while True: - if on_step(): - break - sess.close() - -if __name__ == '__main__': - listener = Listener(opt, run_live) - listener.connect() diff --git a/inversion/visualize.py b/inversion/visualize.py index d17fe13..2461458 100644 --- a/inversion/visualize.py +++ b/inversion/visualize.py @@ -36,6 +36,9 @@ def data2img(data): rescaled = np.clip(rescaled, 0, 255) return np.rint(rescaled).astype('uint8') +def data2pil(data); + return Image.fromarray(data2img(data), mode='RGB') + def interleave(a, b): res = np.empty([a.shape[0] + b.shape[0]] + list(a.shape[1:]), dtype=a.dtype) res[0::2] = a -- cgit v1.2.3-70-g09d2