summaryrefslogtreecommitdiff
path: root/inversion/live.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/live.py')
-rw-r--r--inversion/live.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 64f400d..2b965c5 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -243,11 +243,12 @@ class Interpolator:
def set_category(self, category):
categories = " ".split(category)
label = np.zeros((BATCH_SIZE, N_CLASS,))
- ratio = 1 / len(categories)
- target_param = self.lerp_params[payload].switch()
+ target_param = self.lerp_params['label'].switch()
for category in categories:
index = int(category)
- label[0, index] = ratio
+ if index > 0 and index < N_CLASS:
+ label[0, index] = 1.0
+ label[0] /= label[0].sum()
target_param.assign(label)
def on_step(self, i, dt, sess):
@@ -292,17 +293,17 @@ class Listener:
def on_ready(self, rpc_client):
self.rpc_client = rpc_client
- print("Starting session")
+ print("Starting session...")
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
self.sess.run(tf.global_variables_initializer())
self.sess.run(tf.tables_initializer())
- print("Building interpolator")
+ print("Building interpolator...")
interpolator.build()
self.rpc_client.send_status('processing', True)
dt = 1 / FPS
for i in range(99999):
if i == 0:
- print("Loading network")
+ print("Loading network...")
elif i == 1:
print("Processing!")
gen_time = time.time()