summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
Diffstat (limited to 'inversion')
-rw-r--r--inversion/listener.py25
-rw-r--r--inversion/live.py111
-rw-r--r--inversion/params_opt.py39
3 files changed, 175 insertions, 0 deletions
diff --git a/inversion/listener.py b/inversion/listener.py
new file mode 100644
index 0000000..a43c33c
--- /dev/null
+++ b/inversion/listener.py
@@ -0,0 +1,25 @@
+import os
+import sys
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
+from rpc import CortexRPC
+
+class Listener:
+ def __init__(self, opt, run_live):
+ self.opt = opt
+ self.run_live = run_live
+ def _set_fn(self, key, value):
+ self.opt[key] = value
+ def _get_fn(self):
+ return self.opt
+ def _cmd_fn(self, cmd, payload):
+ print("got command {}".format(cmd))
+ if cmd == '':
+ pass
+ else:
+ pass
+ def _ready_fn(self, rpc_client):
+ print("Ready!")
+ self.rpc_client = rpc_client
+ self.run_live(self.opt, rpc_client)
+ def connect(self):
+ self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)
diff --git a/inversion/live.py b/inversion/live.py
new file mode 100644
index 0000000..672853a
--- /dev/null
+++ b/inversion/live.py
@@ -0,0 +1,111 @@
+import os
+import sys
+import glob
+import h5py
+import numpy as np
+import params
+import PIL
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tensorflow_hub as hub
+import time
+import visualize as vs
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+from listener import Listener
+from params import Params
+from params_opt import ParamsOpt
+
+params = Params('params_dense.json')
+opt = ParamsOpt()
+
+# --------------------------
+# Make directories.
+# --------------------------
+OUTPUT_DIR = os.path.join('output', tag)
+if not os.path.exists(OUTPUT_DIR):
+ os.makedirs(OUTPUT_DIR)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = True
+
+BATCH_SIZE = 1
+Z_DIM = input_info['z'].get_shape().as_list()[1]
+N_CLASS = input_info['y'].get_shape().as_list()[1]
+
+def sin(key, shape):
+ # speed should be computed outside of tensorflow
+ # (so we can recursively update t = last_t + speed)
+ noise, noise_a, noise_b, noise_n = lerp('sin_noise', shape)
+ scale = tf.get_variable(name=key + '_scale', dtype=tf.float32, shape=(1,))
+ t = tf.get_variable(name=key + '_t', dtype=tf.float32, shape=(1,))
+ out = tf.sin(t + noise) * scale
+ return out, t, scale, noise_a, noise_b, noise_n
+
+def lerp(key, shape):
+ a = tf.get_variable(name=key + '_a', dtype=tf.float32, shape=shape)
+ b = tf.get_variable(name=key + '_b', dtype=tf.float32, shape=shape)
+ n = tf.get_variable(name=key + '_n', dtype=tf.float32, shape=(1,))
+ out = a * (1 - n) + b * n
+ return out, a, b, n
+
+lerp_z, z_a, z_b, z_n = lerp('latent', [BATCH_SIZE, Z_DIM])
+sin_z, sin_t, sin_scale, sin_noise_a, sin_noise_b, sin_noise_n = lerp('sin_z', [BATCH_SIZE, Z_DIM])
+lerp_label, label_a, label_b, label_n = lerp('label', [BATCH_SIZE, N_CLASS])
+
+gen_in = dict(params.generator_fixed_inputs)
+gen_in['z'] = lerp_z + sin_z
+gen_in['y'] = lerp_label
+gen_img = generator(gen_in, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+# shape=[BATCH_SIZE,] + ENC_SHAPE)
+# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+t = time.time()
+
+def on_step():
+ # local variables to update:
+ # t, sin_speed, sin_t
+ # variables to assign:
+ # z_a, z_b, z_n
+ # label_a, label_b, label_n
+ # sin_t, sin_noise, sin_scale, sin_amount
+ # sin_noise_a, sin_noise_b, sin_noise_n
+ # sess.run([
+ # target.assign(image_batch)
+ # ])
+ # sess.run(label.assign(label_batch))
+ gen_time = time.time()
+ gen_images = sess.run(gen_img)
+ print("Generation time: {:.1f}s".format(time.time() - gen_time))
+ # convert to png and send this back...
+ out_img = vs.data2img(image_batch[0])
+ pass
+
+def run_live():
+ while True:
+ if on_step():
+ break
+ sess.close()
+
+if __name__ == '__main__':
+ listener = Listener(opt, run_live)
+ listener.connect()
diff --git a/inversion/params_opt.py b/inversion/params_opt.py
new file mode 100644
index 0000000..7333cb9
--- /dev/null
+++ b/inversion/params_opt.py
@@ -0,0 +1,39 @@
+# ------------------------------------------------------------------------------
+# Util class for hyperparams.
+# ------------------------------------------------------------------------------
+
+import json
+
+class Params():
+ """Class that loads hyperparameters from a json file."""
+
+ def __init__(self, json_path):
+ self.update(json_path)
+
+ def __setitem__(self, key, item):
+ self.__dict__[key] = item
+
+ def __getitem__(self, key):
+ return self.__dict__[key]
+
+ def __repr__(self):
+ return repr(self.__dict__)
+
+ def __len__(self):
+ return len(self.__dict__)
+
+ def __delitem__(self, key):
+ del self.__dict__[key]
+
+ def save(self, json_path):
+ """Saves parameters to json file."""
+ with open(json_path, 'w') as f:
+ json.dump(self.__dict__, f, indent=4)
+
+ def update(self, *args, **kwargs):
+ return self.__dict__.update(*args, **kwargs)
+
+ # @property
+ # def dict(self):
+ # """Gives dict-like access to Params instance."""
+ # return self.__dict__