summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-19 16:19:09 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-19 16:19:09 +0100
commitbe56a4688bc4f287699cdb772766e4d6371af1b5 (patch)
tree9f76cb9fc55f6bd5ccfc09236f0e2969ec73e96a /inversion
parent8472f4588d4ddd340826ad1e370443f6819f18bc (diff)
parente6e30fa4c0dfedb009ec24d83b9661599a90b4f1 (diff)
merge
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py77
1 files changed, 39 insertions, 38 deletions
diff --git a/inversion/live.py b/inversion/live.py
index b233724..ad4b80f 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -8,6 +8,7 @@ import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
+from PIL import Image
import visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../live-cortex/rpc/'))
@@ -15,7 +16,7 @@ from rpc import CortexRPC
from params import Params
-params = Params('params_dense.json')
+params = Params('params_dense-512.json')
# --------------------------
# Make directories.
@@ -28,10 +29,7 @@ if not os.path.exists(OUTPUT_DIR):
# --------------------------
# Load Graph.
# --------------------------
-sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
-sess.run(tf.global_variables_initializer())
-sess.run(tf.tables_initializer())
-
+print("Loading module...")
generator = hub.Module(str(params.generator_path))
gen_signature = 'generator'
@@ -39,7 +37,6 @@ if 'generator' not in generator.get_signature_names():
gen_signature = 'default'
input_info = generator.get_input_info_dict(gen_signature)
-COND_GAN = True
BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]
@@ -68,9 +65,7 @@ def sin(opts, key, shape):
scale = InterpolatorParam(name=key + '_scale')
time = opts['time'].variable
out = tf.sin(time + noise) * scale.variable
- opts[key] = {
- 'scale': scale,
- }
+ opts[key + '_scale'] = scale
return out
def lerp(opts, key, shape):
@@ -78,11 +73,9 @@ def lerp(opts, key, shape):
b = InterpolatorParam(name=key + '_b', shape=shape)
n = InterpolatorParam(name=key + '_n')
out = a.variable * (1 - n.variable) + b.variable * n.variable
- opts[key] = {
- 'a': a,
- 'b': b,
- 'n': n,
- }
+ opts[key + '_a'] = a
+ opts[key + '_b'] = b
+ opts[key + '_n'] = n
return out
class InterpolatorParam:
@@ -91,11 +84,10 @@ class InterpolatorParam:
self.shape = shape
self.type = type
self.value = value or np.zeros(shape)
- self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape)
+ self.variable = tf.placeholder(dtype=dtype, shape=shape)
- def assign(value):
+ def assign(self, value):
self.value = value
- return self.variable.assign(value)
def randomize(self, num_classes):
if self.type == 'noise':
@@ -107,16 +99,14 @@ class InterpolatorParam:
class Interpolator:
def __init__(self):
self.opts = {
- 'global': {
- 'time': InterpolatorParam(name='t', value=time.time())
- },
+ 'time': InterpolatorParam(name='t', value=time.time()),
+ 'truncation' : InterpolatorParam(name='truncation', value=1.0),
}
def build(self):
lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM])
sin_z = sin(self.opts, 'sin_z', [BATCH_SIZE, Z_DIM])
lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS])
- self.opts['truncation'] = InterpolatorParam('truncation', value=1.0)
gen_in = {}
gen_in['truncation'] = self.opts['truncation'].variable
@@ -127,12 +117,11 @@ class Interpolator:
# Convert generated image to channels_first.
self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
- for group, lookup in self.opts.items():
- for key, param in group.items():
- if param.scalar:
- param.assign().eval(session=sess)
- else:
- param.randomize().eval(session=sess)
+ def get_feed_dict(self):
+ opt = {}
+ for key, param in self.opts.items():
+ opt[param.variable] = param.value
+ return opt
def get_state(self):
opt = {}
@@ -144,11 +133,8 @@ class Interpolator:
def set_value(self, key, value):
self.opts[key].assign(value)
- def on_step(self, i):
- gen_time = time.time()
- self.opts['global']['time'].assign(gen_time).eval(session=sess)
- gen_images = sess.run(self.gen_img)
- print("Generation time: {:.1f}s".format(time.time() - gen_time))
+ def on_step(self, i, sess):
+ gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
return gen_images
def run(self, cmd, payload):
@@ -157,6 +143,7 @@ class Interpolator:
class Listener:
def __init__(self):
+ self.assign_ops = {}
self.interpolator = Interpolator()
self.interpolator.build()
@@ -164,7 +151,7 @@ class Listener:
self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)
def on_set(self, key, value):
- self.interpolator.set_value(key, value)
+ self.opts[key].assign(value)
def on_get(self):
return self.interpolator.get_state()
@@ -174,20 +161,34 @@ class Listener:
self.interpolator.run(cmd, payload)
def on_ready(self, rpc_client):
+ self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+ self.sess.run(tf.global_variables_initializer())
+ self.sess.run(tf.tables_initializer())
print("Ready!")
self.rpc_client = rpc_client
self.rpc_client.send_status('processing', True)
for i in range(99999):
- gen_images = self.interpolator.on_step(i)
+ print("Step {}".format(i))
+ gen_time = time.time()
+ self.interpolator.opts['time'].assign(gen_time)
+ gen_images = self.interpolator.on_step(i, self.sess)
if gen_images is None:
+ print("Exiting...")
break
+ print("Generation time: {:.1f}s".format(time.time() - gen_time))
out_img = vs.data2pil(gen_images[0])
if out_img is not None:
- if out_img.resize_before_sending:
- out_img.resize((256, 256), Image.BICUBIC)
- self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, data_opt.output_format)
+ #if out_img.resize_before_sending:
+ img_to_send = out_img.resize((256, 256), Image.BICUBIC)
+ meta = {
+ 'i': i,
+ 'sequence_i': i,
+ 'skip_i': 0,
+ 'sequence_len': 99999,
+ }
+ self.rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, img_to_send, 'jpg')
self.rpc_client.send_status('processing', False)
- sess.close()
+ self.sess.close()
if __name__ == '__main__':
listener = Listener()