From b9f515173f480a0468fc54d05bf134ebb15b0328 Mon Sep 17 00:00:00 2001
From: Jules Laplace
Date: Wed, 8 Jan 2020 20:17:30 +0100
Subject: moving in the live stuff

---
 cli/app/commands/biggan/live.py |  12 ++
 cli/app/search/live.py          | 361 ++++++++++++++++++++++++++++++++++++++++
 cli/app/search/search_class.py  |   4 +-
 cli/app/search/search_dense.py  |   5 +-
 cli/app/settings/app_cfg.py     |   1 +
 inversion/live.py               |   8 +
 6 files changed, 386 insertions(+), 5 deletions(-)
 create mode 100644 cli/app/commands/biggan/live.py
 create mode 100644 cli/app/search/live.py

diff --git a/cli/app/commands/biggan/live.py b/cli/app/commands/biggan/live.py
new file mode 100644
index 0000000..38d5e45
--- /dev/null
+++ b/cli/app/commands/biggan/live.py
@@ -0,0 +1,12 @@
+import click
+
+@click.command('')
+@click.pass_context
+def cli(ctx):
+  """
+  Run the BigGAN live
+  """
+  from app.search.live import Listener
+
+  listener = Listener()
+  listener.connect()
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
new file mode 100644
index 0000000..0b97f15
--- /dev/null
+++ b/cli/app/search/live.py
@@ -0,0 +1,361 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+import sys
+import glob
+import h5py
+import numpy as np
+import params
+import json
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tensorflow_hub as hub
+import time
+import random
+from scipy.stats import truncnorm
+from PIL import Image
+from urllib.parse import parse_qs
+import app.search.visualize as vs
+from app.search.json import params_dense_dict
+from app.utils.file_utils import load_pickle
+from app.settings import app_cfg
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../../live-cortex/rpc/'))
+from rpc import CortexRPC
+
+FPS = 25
+
+params = params_dense_dict('live')
+
+# --------------------------
+# Make directories.
+# --------------------------
+tag = "test"
+OUTPUT_DIR = os.path.join('output', tag)
+if not os.path.exists(OUTPUT_DIR):
+  os.makedirs(OUTPUT_DIR)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+print("Loading module...")
+generator = hub.Module(str(params.generator_path))
+print("Loaded!")
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+
+BATCH_SIZE = 1
+Z_DIM = input_info['z'].get_shape().as_list()[1]
+N_CLASS = input_info['y'].get_shape().as_list()[1]
+
+# --------------------------
+# Utils
+# --------------------------
+
+def clamp(n, a=0, b=1):
+  if n < a:
+    return a
+  if n > b:
+    return b
+  return n
+
+# --------------------------
+# Initializers
+# --------------------------
+
+def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)):
+  label = np.zeros(shape)
+  for i in range(shape[0]):
+    for _ in range(int(num_classes)):
+      j = random.randint(0, shape[1]-1)
+      label[i, j] = random.random()
+      print("class: {} {}".format(j, label[i, j]))
+    label[i] /= label[i].sum()
+  return label
+
+def truncated_z_sample(shape=(BATCH_SIZE, Z_DIM,), truncation=1.0):
+  values = truncnorm.rvs(-2, 2, size=shape)
+  return truncation * values
+
+def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
+  return np.random.normal(size=shape)
+
+# --------------------------
+# More complex ops
+# --------------------------
+
+class SinParam:
+  def __init__(self, name, shape, datatype="noise"):
+    noise = LerpParam(name + '_noise', shape=shape, datatype=datatype)
+    orbit_radius = InterpolatorParam(name=name + '_radius', value=0.25)
+    orbit_speed = InterpolatorParam(name=name + '_speed', value=1.0)
+    orbit_time = InterpolatorParam(name=name + '_time', value=0.0)
+    output = tf.math.sin(orbit_time.variable + noise.output) * orbit_radius.variable
+    interpolator.sin_params[name] = self
+    self.name = name
+    self.orbit_speed = orbit_speed
+    self.orbit_time = orbit_time
+    self.output = output
+
+  def update(self, dt):
+    self.orbit_time.value += self.orbit_speed.value * dt
+
+class LerpParam:
+  def __init__(self, name, shape, a_in=None, b_in=None, datatype="noise"):
+    if a_in is not None and b_in is not None:
+      a = InterpolatorVariable(variable=a_in)
+      b = InterpolatorVariable(variable=b_in)
+    else:
+      a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
+      b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
+    n = InterpolatorParam(name=name + '_n', value=0.0)
+    speed = InterpolatorParam(name=name + '_speed', value=1.0)
+    output = a.variable * (1 - n.variable) + b.variable * n.variable
+    interpolator.lerp_params[name] = self
+    self.name = name
+    self.a = a
+    self.b = b
+    self.n = n
+    self.speed = speed
+    self.output = output
+    self.direction = 0
+
+  def switch(self, target_value=None):
+    if self.n.value > 0.5:
+      target_param = self.a
+      self.direction = -1
+    else:
+      target_param = self.b
+      self.direction = 1
+    if target_value is None:
+      target_param.randomize()
+    else:
+      target_param.assign(target_value)
+
+  def update(self, dt):
+    if self.direction != 0:
+      self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt)
+      if self.n.value == 0 or self.n.value == 1:
+        self.direction = 0
+
+# --------------------------
+# Placeholder params
+# --------------------------
+
+class InterpolatorParam:
+  def __init__(self, name, dtype=tf.float32, shape=(), value=None, datatype="float"):
+    self.scalar = shape == ()
+    self.shape = shape
+    self.datatype = datatype
+    if datatype == "float":
+      self.assign(value or 0.0)
+    else:
+      self.randomize()
+    self.variable = tf.placeholder(dtype=dtype, shape=shape)
+    interpolator.opts[name] = self
+
+  def assign(self, value):
+    if self.datatype == 'float':
+      self.value = float(value)
+    else:
+      self.value = value
+
+  def randomize(self):
+    if self.datatype == 'noise':
+      val = truncated_z_sample(shape=self.shape, truncation=interpolator.opts['truncation'].value)
+    elif self.datatype == 'label':
+      val = label_sampler(shape=self.shape, num_classes=interpolator.opts['num_classes'].value)
+    elif self.datatype == 'encoding':
+      val = np.zeros(self.shape)
+    else:
+      val = 0.0
+    self.assign(val)
+
+class InterpolatorVariable:
+  def __init__(self, variable):
+    self.scalar = False
+    self.variable = variable
+
+  def assign(self, value=None):
+    pass
+
+  def randomize(self):
+    pass
+
+# --------------------------
+# Interpolator graph
+# --------------------------
+
+class Interpolator:
+  def __init__(self):
+    self.opts = {}
+    self.sin_params = {}
+    self.lerp_params = {}
+
+  def build(self):
+    InterpolatorParam(name='truncation', value=1.0)
+    InterpolatorParam(name='num_classes', value=1.0)
+    abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0)
+    lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+    sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+    lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
+
+    z_sum = lerp_z.output + sin_z.output
+    z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable
+    z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input")
+
+    # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise')
+    # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label')
+
+    gen_in = {}
+    gen_in['truncation'] = 1.0 # self.opts['truncation'].variable
+    gen_in['z'] = z_mix.output
+    gen_in['y'] = lerp_label.output
+    self.gen_img = generator(gen_in, signature=gen_signature)
+
+    gen_layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+    encoding_latent = tf.get_default_graph().get_tensor_by_name(gen_layer_name)
+
+    encoding_shape = [1,] + encoding_latent.get_shape().as_list()[1:]
+    encoding_stored = LerpParam('encoding_stored', shape=encoding_shape, datatype="encoding")
+    encoding_mix = LerpParam('encoding_mix', a_in=encoding_latent, b_in=encoding_stored.output, shape=encoding_shape, datatype="encoding")
+    tf.contrib.graph_editor.swap_ts(encoding_latent, encoding_mix.output)
+
+    sys.stderr.write("Sin params: {}\n".format(", ".join(self.sin_params.keys())))
+    sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys())))
+    sys.stderr.write("Opts: {}\n".format(", ".join(self.opts.keys())))
+
+  def get_feed_dict(self):
+    opt = {}
+    for param in self.opts.values():
+      opt[param.variable] = param.value
+    return opt
+
+  def get_state(self):
+    opt = {}
+    for key, param in self.opts.items():
+      if param.scalar:
+        if type(param.value) is np.ndarray:
+          sys.stderr.write('{} is ndarray\n'.format(key))
+          opt[key] = param.value.tolist()
+        else:
+          opt[key] = param.value
+    return opt
+
+  def set_value(self, key, value):
+    if key in self.opts:
+      self.opts[key].assign(float(value))
+    else:
+      sys.stderr.write('{} not a valid option\n'.format(key))
+
+  def set_category(self, category):
+    print("Set category: {}".format(category))
+    categories = category.split(" ")
+    label = np.zeros((BATCH_SIZE, N_CLASS,))
+    for category in categories:
+      index = int(category)
+      if index > 0 and index < N_CLASS:
+        label[0, index] = 1.0
+    label[0] /= label[0].sum()
+    self.lerp_params['label'].switch(target_value=label)
+
+  def set_encoding(self, opt):
+    next_id = opt['id']
+    data = load_pickle(os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(next_id)))
+    encoding_stored = self.lerp_params['encoding_stored']
+    encoding_mix = self.lerp_params['encoding_mix']
+    # if we're showing an encoding already, lerp to the next one
+    if encoding_mix.n.value > 0:
+      encoding_stored.switch(target_value=data['encoding'])
+    # otherwise (we're showing the latent)...
+    else:
+      # jump to the stored encoding, then switch
+      if encoding_stored.n.value < 0.5:
+        encoding_stored.n.value = 0
+        encoding_stored.a.assign(data['encoding'])
+      else:
+        encoding_stored.n.value = 1
+        encoding_stored.b.assign(data['encoding'])
+      encoding_mix.switch()
+
+  def on_step(self, i, dt, sess):
+    for param in self.sin_params.values():
+      param.update(dt)
+    for param in self.lerp_params.values():
+      param.update(dt)
+    gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
+    return gen_images
+
+  def run(self, cmd, payload):
+    print("Command: {} {}".format(cmd, payload))
+    if cmd == 'switch' and payload in self.lerp_params:
+      self.lerp_params[payload].switch()
+    if cmd == 'setCategory':
+      self.set_category(payload)
+    if cmd == 'setEncoding':
+      self.set_encoding(parse_qs(payload))
+    pass
+
+# --------------------------
+# RPC Listener
+# --------------------------
+
+interpolator = Interpolator()
+
+class Listener:
+  def connect(self):
+    self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)
+
+  def on_set(self, key, value):
+    print("{}: {} {}".format(key, str(type(value)), value))
+    interpolator.set_value(key, value)
+
+  def on_get(self):
+    state = interpolator.get_state()
+    sys.stderr.write(json.dumps(state) + "\n")
+    sys.stderr.flush()
+    return state
+
+  def on_cmd(self, cmd, payload):
+    print("got command {}".format(cmd))
+    interpolator.run(cmd, payload)
+
+  def on_ready(self, rpc_client):
+    self.rpc_client = rpc_client
+    print("Starting session...")
+    self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+    self.sess.run(tf.global_variables_initializer())
+    self.sess.run(tf.tables_initializer())
+    print("Building interpolator...")
+    interpolator.build()
+    self.rpc_client.send_status('processing', True)
+    dt = 1 / FPS
+    for i in range(99999):
+      if i == 0:
+        print("Loading network...")
+      elif i == 1:
+        print("Processing!")
+      gen_time = time.time()
+      gen_images = interpolator.on_step(i, dt, self.sess)
+      if gen_images is None:
+        print("Exiting...")
+        break
+      if (i % 100) == 0 or i == 1:
+        # print(gen_images.shape)
+        print("Step {}. Generation time: {:.2f}s".format(i, time.time() - gen_time))
+      out_img = vs.data2pil(gen_images[0])
+      if out_img is not None:
+        #if out_img.resize_before_sending:
+        img_to_send = out_img.resize((256, 256), Image.BICUBIC)
+        meta = {
+          'i': i,
+          'sequence_i': i,
+          'skip_i': 0,
+          'sequence_len': 99999,
+        }
+        self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, 'jpg')
+    self.rpc_client.send_status('processing', False)
+    self.sess.close()
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index 8825ca5..875bc75 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -191,9 +191,9 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
       print('iter: {}, loss: {}'.format(i, curr_loss))
     if i > 0:
       if opt_stochastic_clipping and (i % opt_clip_interval) == 0 and i < opt_steps * 0.45:
-        sess.run(clip_latent, { clipped_alpha: i / opt_steps })
+        sess.run(clip_latent, { clipped_alpha: (i / opt_steps) ** 2 })
       if opt_label_clipping and (i % opt_clip_interval) == 0:
-        sess.run(clip_labels, { normalized_alpha: i / opt_steps })
+        sess.run(clip_labels, { normalized_alpha: (i / opt_steps) ** 2 })
     if opt_video and opt_snapshot_interval != 0 and (i % opt_snapshot_interval) == 0:
       phi_guess = sess.run(output)
       guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index d5cbf64..5bf77fc 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -50,11 +50,10 @@ def find_dense_embedding_for_images(params):
   SAMPLE_SIZE = params.sample_size
   LOGS_DIR = os.path.join(params.path, LATENT_TAG, 'logs')
   SAMPLES_DIR = os.path.join(params.path, LATENT_TAG, 'samples')
-  VECTOR_DIR = os.path.join(params.path, 'vectors')
 
   os.makedirs(LOGS_DIR, exist_ok=True)
   os.makedirs(SAMPLES_DIR, exist_ok=True)
-  os.makedirs(VECTOR_DIR, exist_ok=True)
+  os.makedirs(app_cfg.DIR_VECTORS, exist_ok=True)
 
   def one_hot(values):
     return np.eye(N_CLASS)[values]
@@ -426,7 +425,7 @@ def find_dense_embedding_for_images(params):
       data = upload_bytes_to_cortex(params.folder_id, sample_fn + "-inverse.png", fp, "image/png")
     if data is not None:
       file_id = data['id']
-      fp_out_pkl = os.path.join(vector_dir, "file_{}.pkl".format(file_id))
+      fp_out_pkl = os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(file_id))
       out_data = {
         'id': file_id,
         'sample_fn': sample_fn,
diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py
index bade59c..6ddd445 100644
--- a/cli/app/settings/app_cfg.py
+++ b/cli/app/settings/app_cfg.py
@@ -31,6 +31,7 @@ SELF_CWD = os.path.dirname(os.path.realpath(__file__)) # Script CWD
 DIR_APP = str(Path(SELF_CWD).parent.parent.parent)
 DIR_IMAGENET = join(DIR_APP, 'data_store/imagenet')
 DIR_INVERSES = join(DIR_APP, 'data_store/inverses')
+DIR_VECTORS = join(DIR_APP, 'data_store/vectors')
 DIR_INPUTS = join(DIR_APP, 'data_store/inputs')
 DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs')
 FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml')
diff --git a/inversion/live.py b/inversion/live.py
index e8001ef..6642d0a 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -13,6 +13,7 @@ import time
 import random
 from scipy.stats import truncnorm
 from PIL import Image
+from urllib.parse import parse_qs
 import visualize as vs
 tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
@@ -260,6 +261,11 @@ class Interpolator:
     label[0] /= label[0].sum()
     self.lerp_params['label'].switch(target_value=label)
 
+  def set_encoding(self, opt):
+    next_id = opt['id']
+    should_lerp = opt['lerp']
+
+
   def on_step(self, i, dt, sess):
     for param in self.sin_params.values():
       param.update(dt)
@@ -274,6 +280,8 @@ class Interpolator:
       self.lerp_params[payload].switch()
     if cmd == 'setCategory':
       self.set_category(payload)
+    if cmd == 'setEncoding':
+      self.set_encoding(parse_qs(payload))
     pass
 
 # --------------------------
-- 
cgit v1.2.3-70-g09d2