summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/app/commands/biggan/live.py12
-rw-r--r--cli/app/search/live.py361
-rw-r--r--cli/app/search/search_class.py4
-rw-r--r--cli/app/search/search_dense.py5
-rw-r--r--cli/app/settings/app_cfg.py1
-rw-r--r--inversion/live.py8
6 files changed, 386 insertions, 5 deletions
diff --git a/cli/app/commands/biggan/live.py b/cli/app/commands/biggan/live.py
new file mode 100644
index 0000000..38d5e45
--- /dev/null
+++ b/cli/app/commands/biggan/live.py
@@ -0,0 +1,12 @@
+import click
+
+@click.command('')
+@click.pass_context
+def cli(ctx):
+ """
+ Run the BigGAN live
+ """
+ from app.search.live import Listener
+
+ listener = Listener()
+ listener.connect()
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
new file mode 100644
index 0000000..0b97f15
--- /dev/null
+++ b/cli/app/search/live.py
@@ -0,0 +1,361 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+import sys
+import glob
+import h5py
+import numpy as np
+import params
+import json
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tensorflow_hub as hub
+import time
+import random
+from scipy.stats import truncnorm
+from PIL import Image
+from urllib.parse import parse_qs
+import app.search.visualize as vs
+from app.search.json import params_dense_dict
+from app.utils.file_utils import load_pickle
+from app.settings import app_cfg
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../../live-cortex/rpc/'))
+from rpc import CortexRPC
+
+FPS = 25
+
+params = params_dense_dict('live')
+
+# --------------------------
+# Make directories.
+# --------------------------
+tag = "test"
+OUTPUT_DIR = os.path.join('output', tag)
+if not os.path.exists(OUTPUT_DIR):
+ os.makedirs(OUTPUT_DIR)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+print("Loading module...")
+generator = hub.Module(str(params.generator_path))
+print("Loaded!")
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+
+BATCH_SIZE = 1
+Z_DIM = input_info['z'].get_shape().as_list()[1]
+N_CLASS = input_info['y'].get_shape().as_list()[1]
+
+# --------------------------
+# Utils
+# --------------------------
+
+def clamp(n, a=0, b=1):
+ if n < a:
+ return a
+ if n > b:
+ return b
+ return n
+
+# --------------------------
+# Initializers
+# --------------------------
+
+def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)):
+ label = np.zeros(shape)
+ for i in range(shape[0]):
+ for _ in range(int(num_classes)):
+ j = random.randint(0, shape[1]-1)
+ label[i, j] = random.random()
+ print("class: {} {}".format(j, label[i, j]))
+ label[i] /= label[i].sum()
+ return label
+
+def truncated_z_sample(shape=(BATCH_SIZE, Z_DIM,), truncation=1.0):
+ values = truncnorm.rvs(-2, 2, size=shape)
+ return truncation * values
+
+def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
+ return np.random.normal(size=shape)
+
+# --------------------------
+# More complex ops
+# --------------------------
+
+class SinParam:
+ def __init__(self, name, shape, datatype="noise"):
+ noise = LerpParam(name + '_noise', shape=shape, datatype=datatype)
+ orbit_radius = InterpolatorParam(name=name + '_radius', value=0.25)
+ orbit_speed = InterpolatorParam(name=name + '_speed', value=1.0)
+ orbit_time = InterpolatorParam(name=name + '_time', value=0.0)
+ output = tf.math.sin(orbit_time.variable + noise.output) * orbit_radius.variable
+ interpolator.sin_params[name] = self
+ self.name = name
+ self.orbit_speed = orbit_speed
+ self.orbit_time = orbit_time
+ self.output = output
+
+ def update(self, dt):
+ self.orbit_time.value += self.orbit_speed.value * dt
+
+class LerpParam:
+ def __init__(self, name, shape, a_in=None, b_in=None, datatype="noise"):
+ if a_in is not None and b_in is not None:
+ a = InterpolatorVariable(variable=a_in)
+ b = InterpolatorVariable(variable=b_in)
+ else:
+ a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
+ b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
+ n = InterpolatorParam(name=name + '_n', value=0.0)
+ speed = InterpolatorParam(name=name + '_speed', value=1.0)
+ output = a.variable * (1 - n.variable) + b.variable * n.variable
+ interpolator.lerp_params[name] = self
+ self.name = name
+ self.a = a
+ self.b = b
+ self.n = n
+ self.speed = speed
+ self.output = output
+ self.direction = 0
+
+ def switch(self, target_value=None):
+ if self.n.value > 0.5:
+ target_param = self.a
+ self.direction = -1
+ else:
+ target_param = self.b
+ self.direction = 1
+ if target_value is None:
+ target_param.randomize()
+ else:
+ target_param.assign(target_value)
+
+ def update(self, dt):
+ if self.direction != 0:
+ self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt)
+ if self.n.value == 0 or self.n.value == 1:
+ self.direction = 0
+
+# --------------------------
+# Placeholder params
+# --------------------------
+
+class InterpolatorParam:
+ def __init__(self, name, dtype=tf.float32, shape=(), value=None, datatype="float"):
+ self.scalar = shape == ()
+ self.shape = shape
+ self.datatype = datatype
+ if datatype == "float":
+ self.assign(value or 0.0)
+ else:
+ self.randomize()
+ self.variable = tf.placeholder(dtype=dtype, shape=shape)
+ interpolator.opts[name] = self
+
+ def assign(self, value):
+ if self.datatype == 'float':
+ self.value = float(value)
+ else:
+ self.value = value
+
+ def randomize(self):
+ if self.datatype == 'noise':
+ val = truncated_z_sample(shape=self.shape, truncation=interpolator.opts['truncation'].value)
+ elif self.datatype == 'label':
+ val = label_sampler(shape=self.shape, num_classes=interpolator.opts['num_classes'].value)
+    elif self.datatype == 'encoding':
+ val = np.zeros(self.shape)
+ else:
+ val = 0.0
+ self.assign(val)
+
+class InterpolatorVariable:
+ def __init__(self, variable):
+ self.scalar = False
+ self.variable = variable
+
+ def assign(self):
+ pass
+
+ def randomize(self):
+ pass
+
+# --------------------------
+# Interpolator graph
+# --------------------------
+
+class Interpolator:
+ def __init__(self):
+ self.opts = {}
+ self.sin_params = {}
+ self.lerp_params = {}
+
+ def build(self):
+ InterpolatorParam(name='truncation', value=1.0)
+ InterpolatorParam(name='num_classes', value=1.0)
+ abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0)
+ lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+ sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+ lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
+
+ z_sum = lerp_z.output + sin_z.output
+ z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable
+ z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input")
+
+ # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise')
+ # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label')
+
+ gen_in = {}
+ gen_in['truncation'] = 1.0 # self.opts['truncation'].variable
+ gen_in['z'] = z_mix.output
+ gen_in['y'] = lerp_label.output
+ self.gen_img = generator(gen_in, signature=gen_signature)
+
+ gen_layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+ encoding_latent = tf.get_default_graph().get_tensor_by_name(gen_layer_name)
+
+    encoding_shape = [1,] + encoding_latent.get_shape().as_list()[1:]
+    encoding_stored = LerpParam('encoding_stored', shape=encoding_shape, datatype="encoding")
+    encoding_mix = LerpParam('encoding_mix', a_in=encoding_latent, b_in=encoding_stored.output, shape=encoding_shape, datatype="encoding")
+    tf.contrib.graph_editor.swap_ts(encoding_latent, encoding_mix.output)
+
+ sys.stderr.write("Sin params: {}\n".format(", ".join(self.sin_params.keys())))
+ sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys())))
+ sys.stderr.write("Opts: {}\n".format(", ".join(self.opts.keys())))
+
+ def get_feed_dict(self):
+ opt = {}
+ for param in self.opts.values():
+ opt[param.variable] = param.value
+ return opt
+
+ def get_state(self):
+ opt = {}
+ for key, param in self.opts.items():
+ if param.scalar:
+ if type(param.value) is np.ndarray:
+ sys.stderr.write('{} is ndarray\n'.format(key))
+ opt[key] = param.value.tolist()
+ else:
+ opt[key] = param.value
+ return opt
+
+ def set_value(self, key, value):
+ if key in self.opts:
+ self.opts[key].assign(float(value))
+ else:
+ sys.stderr.write('{} not a valid option\n'.format(key))
+
+ def set_category(self, category):
+ print("Set category: {}".format(category))
+ categories = category.split(" ")
+ label = np.zeros((BATCH_SIZE, N_CLASS,))
+ for category in categories:
+ index = int(category)
+ if index > 0 and index < N_CLASS:
+ label[0, index] = 1.0
+ label[0] /= label[0].sum()
+ self.lerp_params['label'].switch(target_value=label)
+
+ def set_encoding(self, opt):
+ next_id = opt['id']
+ data = load_pickle(os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(next_id)))
+ encoding_stored = self.lerp_params['encoding_stored']
+ encoding_mix = self.lerp_params['encoding_mix']
+ # if we're showing an encoding already, lerp to the next one
+ if encoding_mix.n.value > 0:
+ encoding_stored.switch(target_value=data['encoding'])
+ # otherwise (we're showing the latent)...
+ else:
+ # jump to the stored encoding, then switch
+ if encoding_stored.n.value < 0.5:
+ encoding_stored.n.value = 0
+ encoding_stored.a.assign(data['encoding'])
+ else:
+ encoding_stored.n.value = 1
+ encoding_stored.b.assign(data['encoding'])
+ encoding_mix.switch()
+
+ def on_step(self, i, dt, sess):
+ for param in self.sin_params.values():
+ param.update(dt)
+ for param in self.lerp_params.values():
+ param.update(dt)
+ gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
+ return gen_images
+
+ def run(self, cmd, payload):
+ print("Command: {} {}".format(cmd, payload))
+ if cmd == 'switch' and payload in self.lerp_params:
+ self.lerp_params[payload].switch()
+ if cmd == 'setCategory':
+ self.set_category(payload)
+ if cmd == 'setEncoding':
+ self.set_encoding(parse_qs(payload))
+ pass
+
+# --------------------------
+# RPC Listener
+# --------------------------
+
+interpolator = Interpolator()
+
+class Listener:
+ def connect(self):
+ self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)
+
+ def on_set(self, key, value):
+ print("{}: {} {}".format(key, str(type(value)), value))
+ interpolator.set_value(key, value)
+
+ def on_get(self):
+ state = interpolator.get_state()
+ sys.stderr.write(json.dumps(state) + "\n")
+ sys.stderr.flush()
+ return state
+
+ def on_cmd(self, cmd, payload):
+ print("got command {}".format(cmd))
+ interpolator.run(cmd, payload)
+
+ def on_ready(self, rpc_client):
+ self.rpc_client = rpc_client
+ print("Starting session...")
+ self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+ self.sess.run(tf.global_variables_initializer())
+ self.sess.run(tf.tables_initializer())
+ print("Building interpolator...")
+ interpolator.build()
+ self.rpc_client.send_status('processing', True)
+ dt = 1 / FPS
+ for i in range(99999):
+ if i == 0:
+ print("Loading network...")
+ elif i == 1:
+ print("Processing!")
+ gen_time = time.time()
+ gen_images = interpolator.on_step(i, dt, self.sess)
+ if gen_images is None:
+ print("Exiting...")
+ break
+ if (i % 100) == 0 or i == 1:
+ # print(gen_images.shape)
+ print("Step {}. Generation time: {:.2f}s".format(i, time.time() - gen_time))
+ out_img = vs.data2pil(gen_images[0])
+ if out_img is not None:
+ #if out_img.resize_before_sending:
+ img_to_send = out_img.resize((256, 256), Image.BICUBIC)
+ meta = {
+ 'i': i,
+ 'sequence_i': i,
+ 'skip_i': 0,
+ 'sequence_len': 99999,
+ }
+ self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, 'jpg')
+ self.rpc_client.send_status('processing', False)
+ self.sess.close()
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index 8825ca5..875bc75 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -191,9 +191,9 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
print('iter: {}, loss: {}'.format(i, curr_loss))
if i > 0:
if opt_stochastic_clipping and (i % opt_clip_interval) == 0 and i < opt_steps * 0.45:
- sess.run(clip_latent, { clipped_alpha: i / opt_steps })
+ sess.run(clip_latent, { clipped_alpha: (i / opt_steps) ** 2 })
if opt_label_clipping and (i % opt_clip_interval) == 0:
- sess.run(clip_labels, { normalized_alpha: i / opt_steps })
+ sess.run(clip_labels, { normalized_alpha: (i / opt_steps) ** 2 })
if opt_video and opt_snapshot_interval != 0 and (i % opt_snapshot_interval) == 0:
phi_guess = sess.run(output)
guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index d5cbf64..5bf77fc 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -50,11 +50,10 @@ def find_dense_embedding_for_images(params):
SAMPLE_SIZE = params.sample_size
LOGS_DIR = os.path.join(params.path, LATENT_TAG, 'logs')
SAMPLES_DIR = os.path.join(params.path, LATENT_TAG, 'samples')
- VECTOR_DIR = os.path.join(params.path, 'vectors')
os.makedirs(LOGS_DIR, exist_ok=True)
os.makedirs(SAMPLES_DIR, exist_ok=True)
- os.makedirs(VECTOR_DIR, exist_ok=True)
+ os.makedirs(app_cfg.DIR_VECTORS, exist_ok=True)
def one_hot(values):
return np.eye(N_CLASS)[values]
@@ -426,7 +425,7 @@ def find_dense_embedding_for_images(params):
data = upload_bytes_to_cortex(params.folder_id, sample_fn + "-inverse.png", fp, "image/png")
if data is not None:
file_id = data['id']
- fp_out_pkl = os.path.join(vector_dir, "file_{}.pkl".format(file_id))
+ fp_out_pkl = os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(file_id))
out_data = {
'id': file_id,
'sample_fn': sample_fn,
diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py
index bade59c..6ddd445 100644
--- a/cli/app/settings/app_cfg.py
+++ b/cli/app/settings/app_cfg.py
@@ -31,6 +31,7 @@ SELF_CWD = os.path.dirname(os.path.realpath(__file__)) # Script CWD
DIR_APP = str(Path(SELF_CWD).parent.parent.parent)
DIR_IMAGENET = join(DIR_APP, 'data_store/imagenet')
DIR_INVERSES = join(DIR_APP, 'data_store/inverses')
+DIR_VECTORS = join(DIR_APP, 'data_store/vectors')
DIR_INPUTS = join(DIR_APP, 'data_store/inputs')
DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs')
FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml')
diff --git a/inversion/live.py b/inversion/live.py
index e8001ef..6642d0a 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -13,6 +13,7 @@ import time
import random
from scipy.stats import truncnorm
from PIL import Image
+from urllib.parse import parse_qs
import visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
@@ -260,6 +261,11 @@ class Interpolator:
label[0] /= label[0].sum()
self.lerp_params['label'].switch(target_value=label)
+ def set_encoding(self, opt):
+ next_id = opt['id']
+ should_lerp = opt['lerp']
+
+
def on_step(self, i, dt, sess):
for param in self.sin_params.values():
param.update(dt)
@@ -274,6 +280,8 @@ class Interpolator:
self.lerp_params[payload].switch()
if cmd == 'setCategory':
self.set_category(payload)
+ if cmd == 'setEncoding':
+ self.set_encoding(parse_qs(payload))
pass
# --------------------------