summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/live.py129
1 files changed, 87 insertions, 42 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 5bd26fe..58cfa1e 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -50,6 +50,17 @@ Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]
# --------------------------
+# Utils
+# --------------------------
+
+def clamp(n, a=0, b=1):
+ if n < a:
+ return a
+ if n > b:
+ return b
+ return n
+
+# --------------------------
# Initializers
# --------------------------
@@ -73,27 +84,51 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
# More complex ops
# --------------------------
-def sin(opts, key, shape, type=type):
- noise = lerp(opts, key + '_noise', shape, type=type)
- scale = InterpolatorParam(name=key + '_scale', value=0.1)
- speed = InterpolatorParam(name=key + '_speed', value=1.0)
- time = opts['time'].variable
- out = tf.math.sin(time + noise) * scale.variable
- opts[key + '_scale'] = scale
- opts[key + '_speed'] = speed
- return out
+class SinParam:
+ def __init__(self, key, shape, type):
+ noise = LerpParam(key + '_noise', shape, type=type)
+ orbit_radius = InterpolatorParam(name=key + '_radius', value=0.1)
+ orbit_speed = InterpolatorParam(name=key + '_speed', value=0.1)
+ orbit_time = InterpolatorParam(name=key + '_time', value=0.0)
+ output = tf.math.sin(orbit_time.variable + noise.output) * orbit_radius.variable
+ interpolator.sin_params[key] = self
+ self.key = key
+ self.orbit_speed = orbit_speed
+ self.orbit_time = orbit_time
+ self.output = output
+
+ def update(self, dt):
+ self.orbit_time.value += self.orbit_speed.value * dt
+
+class LerpParam:
+ def __init__(self, key, shape, type="noise"):
+ a = InterpolatorParam(name=key + '_a', shape=shape, type=type)
+ b = InterpolatorParam(name=key + '_b', shape=shape, type=type)
+ n = InterpolatorParam(name=key + '_n', value=0.0)
+ speed = InterpolatorParam(name=key + '_speed', value=0.1)
+ output = a.variable * (1 - n.variable) + b.variable * n.variable
+ interpolator.lerp_params[key] = self
+ self.key = key
+ self.a = a
+ self.b = b
+ self.n = n
+ self.speed = speed
+ self.output = output
+ self.direction = 0
-def lerp(opts, key, shape, type='noise'):
- a = InterpolatorParam(name=key + '_a', shape=shape, type=type)
- b = InterpolatorParam(name=key + '_b', shape=shape, type=type)
- n = InterpolatorParam(name=key + '_n', value=0.0)
- speed = InterpolatorParam(name=key + '_speed', value=0.1)
- out = a.variable * (1 - n.variable) + b.variable * n.variable
- opts[key + '_a'] = a
- opts[key + '_b'] = b
- opts[key + '_n'] = n
- opts[key + '_speed'] = speed
- return out
+ def switch(self):
+ if self.n > 0.5:
+ self.a.randomize()
+ self.direction = -1
+ else:
+ self.b.randomize()
+ self.direction = 1
+
+ def update(self, dt):
+ if self.direction != 0:
+ self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt)
+ if self.n.value == 0 or self.n.value == 1:
+ self.direction = 0
# --------------------------
# Placeholder params
@@ -106,15 +141,16 @@ class InterpolatorParam:
self.type = type
self.value = value if value is not None else np.zeros(shape)
self.variable = tf.placeholder(dtype=dtype, shape=shape)
+ interpolator.opts[key] = self
def assign(self, value):
self.value = value
- def randomize(self, num_classes=1, truncation=1.0):
+ def randomize(self):
if self.type == 'noise':
- val = truncated_z_sample(shape=self.shape, truncation=truncation)
+ val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value)
elif self.type == 'label':
- val = label_sampler(shape=self.shape, num_classes=num_classes)
+ val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value)
self.assign(val)
# --------------------------
@@ -124,24 +160,31 @@ class InterpolatorParam:
class Interpolator:
def __init__(self):
self.opts = {
- 'time': InterpolatorParam(name='t', value=time.time()),
+ # 'time': InterpolatorParam(name='t', value=time.time()),
'truncation' : InterpolatorParam(name='truncation', value=1.0),
+ 'num_classes' : InterpolatorParam(name='num_classes', value=1.0),
}
+ self.sin_params = {}
+ self.lerp_params = {}
def build(self):
- lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM])
- sin_z = sin(self.opts, 'orbit', [BATCH_SIZE, Z_DIM])
- lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS])
+ lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM])
+ sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM])
+ lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS])
# self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], type='noise')
# self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], type='label')
gen_in = {}
gen_in['truncation'] = 1.0 # self.opts['truncation'].variable
- gen_in['z'] = lerp_z + sin_z
- gen_in['y'] = lerp_label
+ gen_in['z'] = lerp_z.output + sin_z.output
+ gen_in['y'] = lerp_label.output
self.gen_img = generator(gen_in, signature=gen_signature)
+ sys.stderr.write("Sin params: {}".format(", ".join(self.sin_params.keys())))
+ sys.stderr.write("Lerp params: {}".format(", ".join(self.lerp_params.keys())))
+ sys.stderr.write("Opts: {}".format(", ".join(self.opts.keys())))
+
# Convert generated image to channels_first.
# self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
@@ -168,46 +211,48 @@ class Interpolator:
else:
sys.stderr.write('{} not a valid option\n'.format(key))
- def on_step(self, i, sess):
- # self.opts['z'].randomize(truncation=1.0)
- # self.opts['y'].randomize(num_classes=1.0)
+ def on_step(self, i, dt, sess):
+ for param in self.sin_params.values():
+ param.update(dt)
+ for param in self.lerp_params.values():
+ param.update(dt)
gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())
return gen_images
def run(self, cmd, payload):
- # do things like create a new B and interpolate to it
+ if cmd == 'switch' and payload in self.lerp_params:
+ self.lerp_params[payload].switch()
pass
# --------------------------
# RPC Listener
# --------------------------
-class Listener:
- def __init__(self):
- self.interpolator = Interpolator()
- self.interpolator.build()
+interpolator = Interpolator()
+class Listener:
def connect(self):
self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd)
def on_set(self, key, value):
- self.interpolator.set_value(key, value)
+ interpolator.set_value(key, value)
def on_get(self):
- state = self.interpolator.get_state()
+ state = interpolator.get_state()
sys.stderr.write(json.dumps(state) + "\n")
sys.stderr.flush()
return state
def on_cmd(self, cmd, payload):
print("got command {}".format(cmd))
- self.interpolator.run(cmd, payload)
+ interpolator.run(cmd, payload)
def on_ready(self, rpc_client):
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.tables_initializer())
print("Ready!")
+ interpolator.build()
self.rpc_client = rpc_client
self.rpc_client.send_status('processing', True)
dt = 1 / FPS
@@ -215,8 +260,8 @@ class Listener:
if (i % 100) == 0:
print("Step {}".format(i))
gen_time = time.time()
- self.interpolator.opts['time'].assign(gen_time)
- gen_images = self.interpolator.on_step(i, self.sess)
+ interpolator.opts['time'].assign(gen_time)
+ gen_images = interpolator.on_step(i, dt, self.sess)
if gen_images is None:
print("Exiting...")
break