summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/live.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 2b965c5..97078ae 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -121,15 +121,17 @@ class LerpParam:
self.output = output
self.direction = 0
- def switch(self):
+ def switch(self, target_value=None):
if self.n.value > 0.5:
- self.a.randomize()
+ target_param = self.a
self.direction = -1
- return self.a
else:
- self.b.randomize()
+ target_param = self.b
self.direction = 1
- return self.b
+ if target_value is None:
+ target_param.randomize()
+ else:
+ target_param.assign(target_value)
def update(self, dt):
if self.direction != 0:
@@ -241,15 +243,15 @@ class Interpolator:
sys.stderr.write('{} not a valid option\n'.format(key))
def set_category(self, category):
+ print("Set category: {}".format(category))
categories = " ".split(category)
label = np.zeros((BATCH_SIZE, N_CLASS,))
- target_param = self.lerp_params['label'].switch()
for category in categories:
index = int(category)
if index > 0 and index < N_CLASS:
label[0, index] = 1.0
label[0] /= label[0].sum()
- target_param.assign(label)
+ self.lerp_params['label'].switch(target_value=label)
def on_step(self, i, dt, sess):
for param in self.sin_params.values():