summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/live.py32
1 files changed, 16 insertions, 16 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 38c98b6..08cc789 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -85,14 +85,14 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
# --------------------------
class SinParam:
- def __init__(self, key, shape, type):
- noise = LerpParam(key + '_noise', shape, type=type)
- orbit_radius = InterpolatorParam(name=key + '_radius', value=0.1)
- orbit_speed = InterpolatorParam(name=key + '_speed', value=0.1)
- orbit_time = InterpolatorParam(name=key + '_time', value=0.0)
+ def __init__(self, name, shape, type):
+ noise = LerpParam(name + '_noise', shape, type=type)
+ orbit_radius = InterpolatorParam(name=name + '_radius', value=0.1)
+ orbit_speed = InterpolatorParam(name=name + '_speed', value=0.1)
+ orbit_time = InterpolatorParam(name=name + '_time', value=0.0)
output = tf.math.sin(orbit_time.variable + noise.output) * orbit_radius.variable
- interpolator.sin_params[key] = self
- self.key = key
+ interpolator.sin_params[name] = self
+ self.name = name
self.orbit_speed = orbit_speed
self.orbit_time = orbit_time
self.output = output
@@ -101,14 +101,14 @@ class SinParam:
self.orbit_time.value += self.orbit_speed.value * dt
class LerpParam:
- def __init__(self, key, shape, type="noise"):
- a = InterpolatorParam(name=key + '_a', shape=shape, type=type)
- b = InterpolatorParam(name=key + '_b', shape=shape, type=type)
- n = InterpolatorParam(name=key + '_n', value=0.0)
- speed = InterpolatorParam(name=key + '_speed', value=0.1)
+ def __init__(self, name, shape, type="noise"):
+ a = InterpolatorParam(name=name + '_a', shape=shape, type=type)
+ b = InterpolatorParam(name=name + '_b', shape=shape, type=type)
+ n = InterpolatorParam(name=name + '_n', value=0.0)
+ speed = InterpolatorParam(name=name + '_speed', value=0.1)
output = a.variable * (1 - n.variable) + b.variable * n.variable
- interpolator.lerp_params[key] = self
- self.key = key
+ interpolator.lerp_params[name] = self
+ self.name = name
self.a = a
self.b = b
self.n = n
@@ -141,7 +141,7 @@ class InterpolatorParam:
self.type = type
self.value = value if value is not None else np.zeros(shape)
self.variable = tf.placeholder(dtype=dtype, shape=shape)
- interpolator.opts[key] = self
+ interpolator.opts[name] = self
def assign(self, value):
self.value = value
@@ -188,7 +188,7 @@ class Interpolator:
def get_feed_dict(self):
opt = {}
- for key, param in self.opts.items():
+ for param in self.opts.values():
opt[param.variable] = param.value
return opt