summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/inversion/live.py b/inversion/live.py
index b07913a..a980e70 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -103,8 +103,12 @@ class SinParam:
class LerpParam:
def __init__(self, name, shape, a_in=None, b_in=None, datatype="noise"):
- a = a_in or InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
- b = b_in or InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
+ if a_in and b_in:
+ a = InterpolatorVariable(variable=a_in)
+ b = InterpolatorVariable(variable=b_in)
+ else:
+ a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
+ b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
n = InterpolatorParam(name=name + '_n', value=0.0)
speed = InterpolatorParam(name=name + '_speed', value=0.1)
output = a.variable * (1 - n.variable) + b.variable * n.variable
@@ -162,6 +166,17 @@ class InterpolatorParam:
val = 0.0
self.assign(val)
+class InterpolatorVariable:
+ def __init__(self, variable):
+ self.scalar = False
+ self.variable = variable
+
+ def assign(self):
+ pass
+
+ def randomize(self):
+ pass
+
# --------------------------
# Interpolator graph
# --------------------------
@@ -173,9 +188,9 @@ class Interpolator:
self.lerp_params = {}
def build(self):
- InterpolatorParam(name='truncation', value=1.0),
- InterpolatorParam(name='num_classes', value=1.0),
- abs_zoom = InterpolatorParam(name='abs_zoom', value=0.0),
+ InterpolatorParam(name='truncation', value=1.0)
+ InterpolatorParam(name='num_classes', value=1.0)
+ abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0)
lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")