summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/listener.py7
-rw-r--r--inversion/live.py183
-rw-r--r--inversion/visualize.py3
3 files changed, 133 insertions, 60 deletions
diff --git a/inversion/listener.py b/inversion/listener.py
index a43c33c..fc04586 100644
--- a/inversion/listener.py
+++ b/inversion/listener.py
@@ -4,7 +4,7 @@ sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../..
from rpc import CortexRPC
class Listener:
- def __init__(self, opt, run_live):
+ def __init__(self, opt, run_live, run_cmd):
self.opt = opt
self.run_live = run_live
def _set_fn(self, key, value):
@@ -13,10 +13,7 @@ class Listener:
return self.opt
def _cmd_fn(self, cmd, payload):
print("got command {}".format(cmd))
- if cmd == '':
- pass
- else:
- pass
+ run_cmd(cmd, payload)
def _ready_fn(self, rpc_client):
print("Ready!")
self.rpc_client = rpc_client
diff --git a/inversion/live.py b/inversion/live.py
index 672853a..8ad38a2 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -4,13 +4,14 @@ import glob
import h5py
import numpy as np
import params
-import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
import visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
+from rpc import CortexRPC
from listener import Listener
from params import Params
@@ -29,6 +30,10 @@ if not os.path.exists(OUTPUT_DIR):
# --------------------------
# Load Graph.
# --------------------------
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
generator = hub.Module(str(params.generator_path))
gen_signature = 'generator'
@@ -42,70 +47,138 @@ BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]
-def sin(key, shape):
- # speed should be computed outside of tensorflow
- # (so we can recursively update t = last_t + speed)
- noise, noise_a, noise_b, noise_n = lerp('sin_noise', shape)
- scale = tf.get_variable(name=key + '_scale', dtype=tf.float32, shape=(1,))
- t = tf.get_variable(name=key + '_t', dtype=tf.float32, shape=(1,))
- out = tf.sin(t + noise) * scale
- return out, t, scale, noise_a, noise_b, noise_n
-def lerp(key, shape):
- a = tf.get_variable(name=key + '_a', dtype=tf.float32, shape=shape)
- b = tf.get_variable(name=key + '_b', dtype=tf.float32, shape=shape)
- n = tf.get_variable(name=key + '_n', dtype=tf.float32, shape=(1,))
+def sin(opts, key, shape):
+ noise = lerp(opts, key + '_noise', shape)
+ scale = InterpolatorParam(name=key + '_scale')
+ time = opts['global']['time'].variable
+ out = tf.sin(time + noise) * scale
+ opts[key] = {
+ 'scale': scale,
+ }
+ return out
+
+def lerp(opts, key, shape):
+ a = InterpolatorParam(name=key + '_a', shape=shape)
+ b = InterpolatorParam(name=key + '_b', shape=shape)
+ n = InterpolatorParam(name=key + '_n')
out = a * (1 - n) + b * n
- return out, a, b, n
+ opts[key] = {
+ 'a': a,
+ 'b': b,
+ 'n': n,
+ }
+ return out
-lerp_z, z_a, z_b, z_n = lerp('latent', [BATCH_SIZE, Z_DIM])
-sin_z, sin_t, sin_scale, sin_noise_a, sin_noise_b, sin_noise_n = lerp('sin_z', [BATCH_SIZE, Z_DIM])
-lerp_label, label_a, label_b, label_n = lerp('label', [BATCH_SIZE, N_CLASS])
+class InterpolatorParam:
+ def __init__(self, name, dtype=tf.float32, shape=(), value=None):
+ self.scalar = shape == ()
+ self.shape = shape
+ self.value = np.zeros(shape) if value == None else
+ self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape)
-gen_in = dict(params.generator_fixed_inputs)
-gen_in['z'] = lerp_z + sin_z
-gen_in['y'] = lerp_label
-gen_img = generator(gen_in, signature=gen_signature)
+ def assign(value):
+ self.value = value
+ return self.variable.assign(value)
-# Convert generated image to channels_first.
-gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+ def randomize(self):
+ return self.assign(np.random.normal(size=self.shape))
-# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
-# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
-# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
-# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
-# shape=[BATCH_SIZE,] + ENC_SHAPE)
-# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+class Interpolator:
+ def __init__(self):
+ self.opts = {}
+ self.t = time.time()
+
+ def build(self):
+ opts = {
+ 'global': {
+ 'time': InterpolatorParam(name='t', value=time.time())
+ },
+ }
+ lerp_z = lerp(opts, 'latent', [BATCH_SIZE, Z_DIM])
+ sin_z = sin(opts, 'sin_z', [BATCH_SIZE, Z_DIM])
+ lerp_label = lerp(opts, 'label', [BATCH_SIZE, N_CLASS])
+ opts['threshold'] = InterpolatorParam('threshold', value=1.0)
+
+ gen_in = {}
+ gen_in['threshold'] = opts['threshold'].variable
+ gen_in['z'] = lerp_z + sin_z
+ gen_in['y'] = lerp_label
+ gen_img = generator(gen_in, signature=gen_signature)
+
+ # Convert generated image to channels_first.
+ self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+ for group, lookup in self.opts.items():
+ for key, param in group.items():
+ if param.scalar:
+ param.assign().eval(session=sess)
+ else:
+ param.randomize().eval(session=sess)
+
+ def get_state(self):
+ opt = {}
+ for group, lookup in self.opts.items():
+ for key, param in group.items():
+ if param.scalar:
+ opt[group][key] = param.value
+ return opt
-IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+ def set_value(self, key, value):
+ self.opts[key].assign(value).eval(session=sess)
-t = time.time()
+ def on_step(i):
+ gen_time = time.time()
+ self.opts['global']['time'].assign(gen_time).eval(session=sess)
+ gen_images = sess.run(self.gen_img)
+ print("Generation time: {:.1f}s".format(time.time() - gen_time))
+ return gen_images
-def on_step():
- # local variables to update:
- # t, sin_speed, sin_t
- # variables to assign:
- # z_a, z_b, z_n
- # label_a, label_b, label_n
- # sin_t, sin_noise, sin_scale, sin_amount
- # sin_noise_a, sin_noise_b, sin_noise_n
- # sess.run([
- # target.assign(image_batch)
- # ])
- # sess.run(label.assign(label_batch))
- gen_time = time.time()
- gen_images = sess.run(gen_img)
- print("Generation time: {:.1f}s".format(time.time() - gen_time))
- # convert to png and send this back...
- out_img = vs.data2img(image_batch[0])
- pass
+ def run(cmd, payload):
+ # do things like create a new B and interpolate to it
+ break
-def run_live():
- while True:
- if on_step():
- break
- sess.close()
+class Listener:
+ def __init__(self):
+ self.interpolator = Interpolator()
+ self.interpolator.build()
+
+ def connect(self):
+ self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)
+
+ def on_set(self, key, value):
+ self.interpolator.set_value(key, value)
+
+ def on_get(self):
+ return self.interpolator.get_state()
+
+ def on_cmd(self, cmd, payload):
+ print("got command {}".format(cmd))
+ self.interpolator.run(cmd, payload)
+
+ def on_ready(self, rpc_client):
+ print("Ready!")
+ self.rpc_client = rpc_client
+ self.rpc_client.send_status('processing', True)
+ for i in range(99999):
+ gen_images = self.interpolator.on_step(i):
+ if gen_images is None:
+ break
+ out_img = vs.data2pil(gen_images[0])
+ if out_img is not None:
+ if out_img.resize_before_sending:
+ out_img.resize((256, 256), Image.BICUBIC)
+ self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, data_opt.output_format)
+ self.rpc_client.send_status('processing', False)
+ sess.close()
if __name__ == '__main__':
- listener = Listener(opt, run_live)
+ listener = Listener()
listener.connect()
+
+# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+# shape=[BATCH_SIZE,] + ENC_SHAPE)
+# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
diff --git a/inversion/visualize.py b/inversion/visualize.py
index d17fe13..2461458 100644
--- a/inversion/visualize.py
+++ b/inversion/visualize.py
@@ -36,6 +36,9 @@ def data2img(data):
rescaled = np.clip(rescaled, 0, 255)
return np.rint(rescaled).astype('uint8')
+def data2pil(data);
+ return Image.fromarray(data2img(data), mode='RGB')
+
def interleave(a, b):
res = np.empty([a.shape[0] + b.shape[0]] + list(a.shape[1:]), dtype=a.dtype)
res[0::2] = a