summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2019-12-19 01:48:32 +0100
committerjules@lens <julescarbon@gmail.com>2019-12-19 01:48:32 +0100
commite6e30fa4c0dfedb009ec24d83b9661599a90b4f1 (patch)
treef6d087ae19345984af936cb72bafa746666f3fad /inversion
parent559966019c005175169ed5a68edce9ce8acc0785 (diff)
feed dict is the key
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py89
1 files changed, 44 insertions, 45 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 4e42703..cf7f761 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -8,6 +8,7 @@ import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
+from PIL import Image
import visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
@@ -15,7 +16,7 @@ from rpc import CortexRPC
from params import Params
-params = Params('params_dense.json')
+params = Params('params_dense-512.json')
# --------------------------
# Make directories.
@@ -28,10 +29,7 @@ if not os.path.exists(OUTPUT_DIR):
# --------------------------
# Load Graph.
# --------------------------
-sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
-sess.run(tf.global_variables_initializer())
-sess.run(tf.tables_initializer())
-
+print("Loading module...")
generator = hub.Module(str(params.generator_path))
gen_signature = 'generator'
@@ -39,21 +37,17 @@ if 'generator' not in generator.get_signature_names():
gen_signature = 'default'
input_info = generator.get_input_info_dict(gen_signature)
-COND_GAN = True
BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]
-
def sin(opts, key, shape):
noise = lerp(opts, key + '_noise', shape)
scale = InterpolatorParam(name=key + '_scale')
- time = opts['global']['time'].variable
+ time = opts['time'].variable
out = tf.sin(time + noise) * scale.variable
- opts[key] = {
- 'scale': scale,
- }
+ opts[key + '_scale'] = scale
return out
def lerp(opts, key, shape):
@@ -61,11 +55,9 @@ def lerp(opts, key, shape):
b = InterpolatorParam(name=key + '_b', shape=shape)
n = InterpolatorParam(name=key + '_n')
out = a.variable * (1 - n.variable) + b.variable * n.variable
- opts[key] = {
- 'a': a,
- 'b': b,
- 'n': n,
- }
+ opts[key + '_a'] = a
+ opts[key + '_b'] = b
+ opts[key + '_n'] = n
return out
class InterpolatorParam:
@@ -73,11 +65,10 @@ class InterpolatorParam:
self.scalar = shape == ()
self.shape = shape
self.value = value or np.zeros(shape)
- self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape)
+ self.variable = tf.placeholder(dtype=dtype, shape=shape)
- def assign(value):
+ def assign(self, value):
self.value = value
- return self.variable.assign(value)
def randomize(self):
return self.assign(np.random.normal(size=self.shape))
@@ -85,16 +76,14 @@ class InterpolatorParam:
class Interpolator:
def __init__(self):
self.opts = {
- 'global': {
- 'time': InterpolatorParam(name='t', value=time.time())
- },
+ 'time': InterpolatorParam(name='t', value=time.time()),
+ 'truncation' : InterpolatorParam(name='truncation', value=1.0),
}
def build(self):
lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM])
sin_z = sin(self.opts, 'sin_z', [BATCH_SIZE, Z_DIM])
lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS])
- self.opts['truncation'] = InterpolatorParam('truncation', value=1.0)
gen_in = {}
gen_in['truncation'] = self.opts['truncation'].variable
@@ -105,29 +94,24 @@ class Interpolator:
# Convert generated image to channels_first.
self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
- for group, lookup in self.opts.items():
- for key, param in group.items():
- if param.scalar:
- param.assign().eval(session=sess)
- else:
- param.randomize().eval(session=sess)
+ def get_feed_dict(self):
+ opt = {}
+ for key, param in self.opts.items():
+ opt[param.variable] = param.value
+ return opt
def get_state(self):
opt = {}
- for group, lookup in self.opts.items():
- for key, param in group.items():
- if param.scalar:
- opt[group][key] = param.value
+ for key, param in self.opts.items():
+ if param.scalar:
+ opt[key] = param.value
return opt
def set_value(self, key, value):
- self.opts[key].assign(value).eval(session=sess)
+ return self.opts[key].assign(value)
- def on_step(self, i):
- gen_time = time.time()
- self.opts['global']['time'].assign(gen_time).eval(session=sess)
- gen_images = sess.run(self.gen_img)
- print("Generation time: {:.1f}s".format(time.time() - gen_time))
+ def on_step(self, i, sess):
+ gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
return gen_images
def run(self, cmd, payload):
@@ -136,6 +120,7 @@ class Interpolator:
class Listener:
def __init__(self):
+ self.assign_ops = {}
self.interpolator = Interpolator()
self.interpolator.build()
@@ -143,7 +128,7 @@ class Listener:
self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)
def on_set(self, key, value):
- self.interpolator.set_value(key, value)
+ self.opts[key].assign(value)
def on_get(self):
return self.interpolator.get_state()
@@ -153,20 +138,34 @@ class Listener:
self.interpolator.run(cmd, payload)
def on_ready(self, rpc_client):
+ self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+ self.sess.run(tf.global_variables_initializer())
+ self.sess.run(tf.tables_initializer())
print("Ready!")
self.rpc_client = rpc_client
self.rpc_client.send_status('processing', True)
for i in range(99999):
- gen_images = self.interpolator.on_step(i)
+ print("Step {}".format(i))
+ gen_time = time.time()
+ self.interpolator.opts['time'].assign(gen_time)
+ gen_images = self.interpolator.on_step(i, self.sess)
if gen_images is None:
+ print("Exiting...")
break
+ print("Generation time: {:.1f}s".format(time.time() - gen_time))
out_img = vs.data2pil(gen_images[0])
if out_img is not None:
- if out_img.resize_before_sending:
- out_img.resize((256, 256), Image.BICUBIC)
- self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, data_opt.output_format)
+ #if out_img.resize_before_sending:
+ img_to_send = out_img.resize((256, 256), Image.BICUBIC)
+ meta = {
+ 'i': i,
+ 'sequence_i': i,
+ 'skip_i': 0,
+ 'sequence_len': 99999,
+ }
+ self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, 'jpg')
self.rpc_client.send_status('processing', False)
- sess.close()
+ self.sess.close()
if __name__ == '__main__':
listener = Listener()