diff options
| author | jules@lens <julescarbon@gmail.com> | 2019-12-19 01:48:32 +0100 |
|---|---|---|
| committer | jules@lens <julescarbon@gmail.com> | 2019-12-19 01:48:32 +0100 |
| commit | e6e30fa4c0dfedb009ec24d83b9661599a90b4f1 (patch) | |
| tree | f6d087ae19345984af936cb72bafa746666f3fad /inversion | |
| parent | 559966019c005175169ed5a68edce9ce8acc0785 (diff) | |
feed dict is the key
Diffstat (limited to 'inversion')
| -rw-r--r-- | inversion/live.py | 89 |
1 files changed, 44 insertions, 45 deletions
diff --git a/inversion/live.py b/inversion/live.py index 4e42703..cf7f761 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -8,6 +8,7 @@ import tensorflow as tf import tensorflow_probability as tfp import tensorflow_hub as hub import time +from PIL import Image import visualize as vs tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/')) @@ -15,7 +16,7 @@ from rpc import CortexRPC from params import Params -params = Params('params_dense.json') +params = Params('params_dense-512.json') # -------------------------- # Make directories. @@ -28,10 +29,7 @@ if not os.path.exists(OUTPUT_DIR): # -------------------------- # Load Graph. # -------------------------- -sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) -sess.run(tf.global_variables_initializer()) -sess.run(tf.tables_initializer()) - +print("Loading module...") generator = hub.Module(str(params.generator_path)) gen_signature = 'generator' @@ -39,21 +37,17 @@ if 'generator' not in generator.get_signature_names(): gen_signature = 'default' input_info = generator.get_input_info_dict(gen_signature) -COND_GAN = True BATCH_SIZE = 1 Z_DIM = input_info['z'].get_shape().as_list()[1] N_CLASS = input_info['y'].get_shape().as_list()[1] - def sin(opts, key, shape): noise = lerp(opts, key + '_noise', shape) scale = InterpolatorParam(name=key + '_scale') - time = opts['global']['time'].variable + time = opts['time'].variable out = tf.sin(time + noise) * scale.variable - opts[key] = { - 'scale': scale, - } + opts[key + '_scale'] = scale return out def lerp(opts, key, shape): @@ -61,11 +55,9 @@ def lerp(opts, key, shape): b = InterpolatorParam(name=key + '_b', shape=shape) n = InterpolatorParam(name=key + '_n') out = a.variable * (1 - n.variable) + b.variable * n.variable - opts[key] = { - 'a': a, - 'b': b, - 'n': n, - } + opts[key + '_a'] = a + opts[key + '_b'] = b + opts[key + '_n'] = n return out class InterpolatorParam: @@ -73,11 +65,10 @@ class InterpolatorParam: self.scalar = shape == () self.shape = shape self.value = value or np.zeros(shape) - self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape) + self.variable = tf.placeholder(dtype=dtype, shape=shape) - def assign(value): + def assign(self, value): self.value = value - return self.variable.assign(value) def randomize(self): return self.assign(np.random.normal(size=self.shape)) @@ -85,16 +76,14 @@ class InterpolatorParam: class Interpolator: def __init__(self): self.opts = { - 'global': { - 'time': InterpolatorParam(name='t', value=time.time()) - }, + 'time': InterpolatorParam(name='t', value=time.time()), + 'truncation' : InterpolatorParam(name='truncation', value=1.0), } def build(self): lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM]) sin_z = sin(self.opts, 'sin_z', [BATCH_SIZE, Z_DIM]) lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS]) - self.opts['truncation'] = InterpolatorParam('truncation', value=1.0) gen_in = {} gen_in['truncation'] = self.opts['truncation'].variable @@ -105,29 +94,24 @@ class Interpolator: # Convert generated image to channels_first. self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) - for group, lookup in self.opts.items(): - for key, param in group.items(): - if param.scalar: - param.assign().eval(session=sess) - else: - param.randomize().eval(session=sess) + def get_feed_dict(self): + opt = {} + for key, param in self.opts.items(): + opt[param.variable] = param.value + return opt def get_state(self): opt = {} - for group, lookup in self.opts.items(): - for key, param in group.items(): - if param.scalar: - opt[group][key] = param.value + for key, param in self.opts.items(): + if param.scalar: + opt[key] = param.value return opt def set_value(self, key, value): - self.opts[key].assign(value).eval(session=sess) + return self.opts[key].assign(value) - def on_step(self, i): - gen_time = time.time() - self.opts['global']['time'].assign(gen_time).eval(session=sess) - gen_images = sess.run(self.gen_img) - print("Generation time: {:.1f}s".format(time.time() - gen_time)) + def on_step(self, i, sess): + gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict()) return gen_images def run(self, cmd, payload): @@ -136,6 +120,7 @@ class Interpolator: class Listener: def __init__(self): + self.assign_ops = {} self.interpolator = Interpolator() self.interpolator.build() @@ -143,7 +128,7 @@ class Listener: self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd) def on_set(self, key, value): - self.interpolator.set_value(key, value) + self.opts[key].assign(value) def on_get(self): return self.interpolator.get_state() @@ -153,20 +138,34 @@ class Listener: self.interpolator.run(cmd, payload) def on_ready(self, rpc_client): + self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) + self.sess.run(tf.global_variables_initializer()) + self.sess.run(tf.tables_initializer()) print("Ready!") self.rpc_client = rpc_client self.rpc_client.send_status('processing', True) for i in range(99999): - gen_images = self.interpolator.on_step(i) + print("Step {}".format(i)) + gen_time = time.time() + self.interpolator.opts['time'].assign(gen_time) + gen_images = self.interpolator.on_step(i, self.sess) if gen_images is None: + print("Exiting...") break + print("Generation time: {:.1f}s".format(time.time() - gen_time)) out_img = vs.data2pil(gen_images[0]) if out_img is not None: - if out_img.resize_before_sending: - out_img.resize((256, 256), Image.BICUBIC) - self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, data_opt.output_format) + #if out_img.resize_before_sending: + img_to_send = out_img.resize((256, 256), Image.BICUBIC) + meta = { + 'i': i, + 'sequence_i': i, + 'skip_i': 0, + 'sequence_len': 99999, + } + self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, 'jpg') self.rpc_client.send_status('processing', False) - sess.close() + self.sess.close() if __name__ == '__main__': listener = Listener() |
