diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-26 17:26:16 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-26 17:26:16 +0100 |
| commit | 49eca5f936cea9900f878018e9c76ba4b752f447 (patch) | |
| tree | 9d16fee9846f861f63ac03d44f6ed773361ecd26 | |
| parent | 6f2c13eb5e35f67153875b2d98ea50db9613f11d (diff) | |
alt method
| -rw-r--r-- | inversion/live.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/inversion/live.py b/inversion/live.py index 2b965c5..97078ae 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -121,15 +121,17 @@ class LerpParam: self.output = output self.direction = 0 - def switch(self): + def switch(self, target_value=None): if self.n.value > 0.5: - self.a.randomize() + target_param = self.a self.direction = -1 - return self.a else: - self.b.randomize() + target_param = self.b self.direction = 1 - return self.b + if target_value is None: + target_param.randomize() + else: + target_param.assign(target_value) def update(self, dt): if self.direction != 0: @@ -241,15 +243,15 @@ class Interpolator: sys.stderr.write('{} not a valid option\n'.format(key)) def set_category(self, category): + print("Set category: {}".format(category)) categories = " ".split(category) label = np.zeros((BATCH_SIZE, N_CLASS,)) - target_param = self.lerp_params['label'].switch() for category in categories: index = int(category) if index > 0 and index < N_CLASS: label[0, index] = 1.0 label[0] /= label[0].sum() - target_param.assign(label) + self.lerp_params['label'].switch(target_value=label) def on_step(self, i, dt, sess): for param in self.sin_params.values(): |
