diff options
Diffstat (limited to 'inversion/live.py')
| -rw-r--r-- | inversion/live.py | 183 |
1 files changed, 128 insertions, 55 deletions
diff --git a/inversion/live.py b/inversion/live.py index 672853a..8ad38a2 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -4,13 +4,14 @@ import glob import h5py import numpy as np import params -import PIL import tensorflow as tf import tensorflow_probability as tfp import tensorflow_hub as hub import time import visualize as vs tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/')) +from rpc import CortexRPC from listener import Listener from params import Params @@ -29,6 +30,10 @@ if not os.path.exists(OUTPUT_DIR): # -------------------------- # Load Graph. # -------------------------- +sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) +sess.run(tf.global_variables_initializer()) +sess.run(tf.tables_initializer()) + generator = hub.Module(str(params.generator_path)) gen_signature = 'generator' @@ -42,70 +47,138 @@ BATCH_SIZE = 1 Z_DIM = input_info['z'].get_shape().as_list()[1] N_CLASS = input_info['y'].get_shape().as_list()[1] -def sin(key, shape): - # speed should be computed outside of tensorflow - # (so we can recursively update t = last_t + speed) - noise, noise_a, noise_b, noise_n = lerp('sin_noise', shape) - scale = tf.get_variable(name=key + '_scale', dtype=tf.float32, shape=(1,)) - t = tf.get_variable(name=key + '_t', dtype=tf.float32, shape=(1,)) - out = tf.sin(t + noise) * scale - return out, t, scale, noise_a, noise_b, noise_n -def lerp(key, shape): - a = tf.get_variable(name=key + '_a', dtype=tf.float32, shape=shape) - b = tf.get_variable(name=key + '_b', dtype=tf.float32, shape=shape) - n = tf.get_variable(name=key + '_n', dtype=tf.float32, shape=(1,)) +def sin(opts, key, shape): + noise = lerp(opts, key + '_noise', shape) + scale = InterpolatorParam(name=key + '_scale') + time = opts['global']['time'].variable + out = tf.sin(time + noise) * scale + opts[key] = { + 'scale': scale, + } + return out + +def lerp(opts, key, shape): + a = InterpolatorParam(name=key + '_a', shape=shape) + b = InterpolatorParam(name=key + '_b', shape=shape) + n = InterpolatorParam(name=key + '_n') out = a * (1 - n) + b * n - return out, a, b, n + opts[key] = { + 'a': a, + 'b': b, + 'n': n, + } + return out -lerp_z, z_a, z_b, z_n = lerp('latent', [BATCH_SIZE, Z_DIM]) -sin_z, sin_t, sin_scale, sin_noise_a, sin_noise_b, sin_noise_n = lerp('sin_z', [BATCH_SIZE, Z_DIM]) -lerp_label, label_a, label_b, label_n = lerp('label', [BATCH_SIZE, N_CLASS]) +class InterpolatorParam: + def __init__(self, name, dtype=tf.float32, shape=(), value=None): + self.scalar = shape == () + self.shape = shape + self.value = np.zeros(shape) if value == None else + self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape) -gen_in = dict(params.generator_fixed_inputs) -gen_in['z'] = lerp_z + sin_z -gen_in['y'] = lerp_label -gen_img = generator(gen_in, signature=gen_signature) + def assign(value): + self.value = value + return self.variable.assign(value) -# Convert generated image to channels_first. -gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) + def randomize(self): + return self.assign(np.random.normal(size=self.shape)) -# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer -# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name) -# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:] -# encoding = tf.get_variable(name='encoding', dtype=tf.float32, -# shape=[BATCH_SIZE,] + ENC_SHAPE) -# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding)) +class Interpolator: + def __init__(self): + self.opts = {} + self.t = time.time() + + def build(self): + opts = { + 'global': { + 'time': InterpolatorParam(name='t', value=time.time()) + }, + } + lerp_z = lerp(opts, 'latent', [BATCH_SIZE, Z_DIM]) + sin_z = sin(opts, 'sin_z', [BATCH_SIZE, Z_DIM]) + lerp_label = lerp(opts, 'label', [BATCH_SIZE, N_CLASS]) + opts['threshold'] = InterpolatorParam('threshold', value=1.0) + + gen_in = {} + gen_in['threshold'] = opts['threshold'].variable + gen_in['z'] = lerp_z + sin_z + gen_in['y'] = lerp_label + gen_img = generator(gen_in, signature=gen_signature) + + # Convert generated image to channels_first. + self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) + + for group, lookup in self.opts.items(): + for key, param in group.items(): + if param.scalar: + param.assign().eval(session=sess) + else: + param.randomize().eval(session=sess) + + def get_state(self): + opt = {} + for group, lookup in self.opts.items(): + for key, param in group.items(): + if param.scalar: + opt[group][key] = param.value + return opt -IMG_SHAPE = gen_img.get_shape().as_list()[1:] + def set_value(self, key, value): + self.opts[key].assign(value).eval(session=sess) -t = time.time() + def on_step(i): + gen_time = time.time() + self.opts['global']['time'].assign(gen_time).eval(session=sess) + gen_images = sess.run(self.gen_img) + print("Generation time: {:.1f}s".format(time.time() - gen_time)) + return gen_images -def on_step(): - # local variables to update: - # t, sin_speed, sin_t - # variables to assign: - # z_a, z_b, z_n - # label_a, label_b, label_n - # sin_t, sin_noise, sin_scale, sin_amount - # sin_noise_a, sin_noise_b, sin_noise_n - # sess.run([ - # target.assign(image_batch) - # ]) - # sess.run(label.assign(label_batch)) - gen_time = time.time() - gen_images = sess.run(gen_img) - print("Generation time: {:.1f}s".format(time.time() - gen_time)) - # convert to png and send this back... - out_img = vs.data2img(image_batch[0]) - pass + def run(cmd, payload): + # do things like create a new B and interpolate to it + break -def run_live(): - while True: - if on_step(): - break - sess.close() +class Listener: + def __init__(self): + self.interpolator = Interpolator() + self.interpolator.build() + + def connect(self): + self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd) + + def on_set(self, key, value): + self.interpolator.set_value(key, value) + + def on_get(self): + return self.interpolator.get_state() + + def on_cmd(self, cmd, payload): + print("got command {}".format(cmd)) + self.interpolator.run(cmd, payload) + + def on_ready(self, rpc_client): + print("Ready!") + self.rpc_client = rpc_client + self.rpc_client.send_status('processing', True) + for i in range(99999): + gen_images = self.interpolator.on_step(i): + if gen_images is None: + break + out_img = vs.data2pil(gen_images[0]) + if out_img is not None: + if out_img.resize_before_sending: + out_img.resize((256, 256), Image.BICUBIC) + self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, data_opt.output_format) + self.rpc_client.send_status('processing', False) + sess.close() if __name__ == '__main__': - listener = Listener(opt, run_live) + listener = Listener() listener.connect() + +# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer +# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name) +# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:] +# encoding = tf.get_variable(name='encoding', dtype=tf.float32, +# shape=[BATCH_SIZE,] + ENC_SHAPE) +# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding)) |
