From 6f2c13eb5e35f67153875b2d98ea50db9613f11d Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Thu, 26 Dec 2019 16:56:59 +0100 Subject: lerp to cat --- inversion/live.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'inversion/live.py') diff --git a/inversion/live.py b/inversion/live.py index 64f400d..2b965c5 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -243,11 +243,12 @@ class Interpolator: def set_category(self, category): categories = " ".split(category) label = np.zeros((BATCH_SIZE, N_CLASS,)) - ratio = 1 / len(categories) - target_param = self.lerp_params[payload].switch() + target_param = self.lerp_params['label'].switch() for category in categories: index = int(category) - label[0, index] = ratio + if index > 0 and index < N_CLASS: + label[0, index] = 1.0 + label[0] /= label[0].sum() target_param.assign(label) def on_step(self, i, dt, sess): @@ -292,17 +293,17 @@ class Listener: def on_ready(self, rpc_client): self.rpc_client = rpc_client - print("Starting session") + print("Starting session...") self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.tables_initializer()) - print("Building interpolator") + print("Building interpolator...") interpolator.build() self.rpc_client.send_status('processing', True) dt = 1 / FPS for i in range(99999): if i == 0: - print("Loading network") + print("Loading network...") elif i == 1: print("Processing!") gen_time = time.time() -- cgit v1.2.3-70-g09d2