diff options
Diffstat (limited to 'inversion/live.py')
| -rw-r--r-- | inversion/live.py | 129 |
1 files changed, 87 insertions, 42 deletions
diff --git a/inversion/live.py b/inversion/live.py index 5bd26fe..58cfa1e 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -50,6 +50,17 @@ Z_DIM = input_info['z'].get_shape().as_list()[1] N_CLASS = input_info['y'].get_shape().as_list()[1] # -------------------------- +# Utils +# -------------------------- + +def clamp(n, a=0, b=1): + if n < a: + return a + if n > b: + return b + return n + +# -------------------------- # Initializers # -------------------------- @@ -73,27 +84,51 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)): # More complex ops # -------------------------- -def sin(opts, key, shape, type=type): - noise = lerp(opts, key + '_noise', shape, type=type) - scale = InterpolatorParam(name=key + '_scale', value=0.1) - speed = InterpolatorParam(name=key + '_speed', value=1.0) - time = opts['time'].variable - out = tf.math.sin(time + noise) * scale.variable - opts[key + '_scale'] = scale - opts[key + '_speed'] = speed - return out +class SinParam: + def __init__(self, key, shape, type): + noise = LerpParam(key + '_noise', shape, type=type) + orbit_radius = InterpolatorParam(name=key + '_radius', value=0.1) + orbit_speed = InterpolatorParam(name=key + '_speed', value=0.1) + orbit_time = InterpolatorParam(name=key + '_time', value=0.0) + output = tf.math.sin(orbit_time.variable + noise.output) * orbit_radius.variable + interpolator.sin_params[key] = self + self.key = key + self.orbit_speed = orbit_speed + self.orbit_time = orbit_time + self.output = output + + def update(self, dt): + self.orbit_time.value += self.orbit_speed.value * dt + +class LerpParam: + def __init__(self, key, shape, type="noise"): + a = InterpolatorParam(name=key + '_a', shape=shape, type=type) + b = InterpolatorParam(name=key + '_b', shape=shape, type=type) + n = InterpolatorParam(name=key + '_n', value=0.0) + speed = InterpolatorParam(name=key + '_speed', value=0.1) + output = a.variable * (1 - n.variable) + b.variable * n.variable + interpolator.lerp_params[key] = self + self.key = key + self.a = a + self.b = b + self.n = n + self.speed = speed + self.output = output + self.direction = 0 -def lerp(opts, key, shape, type='noise'): - a = InterpolatorParam(name=key + '_a', shape=shape, type=type) - b = InterpolatorParam(name=key + '_b', shape=shape, type=type) - n = InterpolatorParam(name=key + '_n', value=0.0) - speed = InterpolatorParam(name=key + '_speed', value=0.1) - out = a.variable * (1 - n.variable) + b.variable * n.variable - opts[key + '_a'] = a - opts[key + '_b'] = b - opts[key + '_n'] = n - opts[key + '_speed'] = speed - return out + def switch(self): + if self.n > 0.5: + self.a.randomize() + self.direction = -1 + else: + self.b.randomize() + self.direction = 1 + + def update(self, dt): + if self.direction != 0: + self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt) + if self.n.value == 0 or self.n.value == 1: + self.direction = 0 # -------------------------- # Placeholder params @@ -106,15 +141,16 @@ class InterpolatorParam: self.type = type self.value = value if value is not None else np.zeros(shape) self.variable = tf.placeholder(dtype=dtype, shape=shape) + interpolator.opts[key] = self def assign(self, value): self.value = value - def randomize(self, num_classes=1, truncation=1.0): + def randomize(self): if self.type == 'noise': - val = truncated_z_sample(shape=self.shape, truncation=truncation) + val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value) elif self.type == 'label': - val = label_sampler(shape=self.shape, num_classes=num_classes) + val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value) self.assign(val) # -------------------------- @@ -124,24 +160,31 @@ class InterpolatorParam: class Interpolator: def __init__(self): self.opts = { - 'time': InterpolatorParam(name='t', value=time.time()), + # 'time': InterpolatorParam(name='t', value=time.time()), 'truncation' : InterpolatorParam(name='truncation', value=1.0), + 'num_classes' : InterpolatorParam(name='num_classes', value=1.0), } + self.sin_params = {} + self.lerp_params = {} def build(self): - lerp_z = lerp(self.opts, 'latent', [BATCH_SIZE, Z_DIM]) - sin_z = sin(self.opts, 'orbit', [BATCH_SIZE, Z_DIM]) - lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS]) + lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM]) + sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM]) + lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS]) # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], type='noise') # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], type='label') gen_in = {} gen_in['truncation'] = 1.0 # self.opts['truncation'].variable - gen_in['z'] = lerp_z + sin_z - gen_in['y'] = lerp_label + gen_in['z'] = lerp_z.output + sin_z.output + gen_in['y'] = lerp_label.output self.gen_img = generator(gen_in, signature=gen_signature) + sys.stderr.write("Sin params: {}".format(", ".join(self.sin_params.keys()))) + sys.stderr.write("Lerp params: {}".format(", ".join(self.lerp_params.keys()))) + sys.stderr.write("Opts: {}".format(", ".join(self.opts.keys()))) + # Convert generated image to channels_first. # self.gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) @@ -168,46 +211,48 @@ class Interpolator: else: sys.stderr.write('{} not a valid option\n'.format(key)) - def on_step(self, i, sess): - # self.opts['z'].randomize(truncation=1.0) - # self.opts['y'].randomize(num_classes=1.0) + def on_step(self, i, dt, sess): + for param in self.sin_params.values(): + param.update(dt) + for param in self.lerp_params.values(): + param.update(dt) gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict()) return gen_images def run(self, cmd, payload): - # do things like create a new B and interpolate to it + if cmd == 'switch' and payload in self.lerp_params: + self.lerp_params[payload].switch() pass # -------------------------- # RPC Listener # -------------------------- -class Listener: - def __init__(self): - self.interpolator = Interpolator() - self.interpolator.build() +interpolator = Interpolator() +class Listener: def connect(self): self.rpc_client = CortexRPC(self.on_get, self.on_set, self.on_ready, self.on_cmd) def on_set(self, key, value): - self.interpolator.set_value(key, value) + interpolator.set_value(key, value) def on_get(self): - state = self.interpolator.get_state() + state = interpolator.get_state() sys.stderr.write(json.dumps(state) + "\n") sys.stderr.flush() return state def on_cmd(self, cmd, payload): print("got command {}".format(cmd)) - self.interpolator.run(cmd, payload) + interpolator.run(cmd, payload) def on_ready(self, rpc_client): self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.tables_initializer()) print("Ready!") + interpolator.build() self.rpc_client = rpc_client self.rpc_client.send_status('processing', True) dt = 1 / FPS @@ -215,8 +260,8 @@ class Listener: if (i % 100) == 0: print("Step {}".format(i)) gen_time = time.time() - self.interpolator.opts['time'].assign(gen_time) - gen_images = self.interpolator.on_step(i, self.sess) + interpolator.opts['time'].assign(gen_time) + gen_images = interpolator.on_step(i, dt, self.sess) if gen_images is None: print("Exiting...") break |
