diff options
Diffstat (limited to 'inversion')
| -rw-r--r-- | inversion/live.py | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/inversion/live.py b/inversion/live.py index b07913a..a980e70 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -103,8 +103,12 @@ class SinParam: class LerpParam: def __init__(self, name, shape, a_in=None, b_in=None, datatype="noise"): - a = a_in or InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype) - b = b_in or InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype) + if a_in and b_in: + a = InterpolatorVariable(variable=a_in) + b = InterpolatorVariable(variable=b_in) + else: + a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype) + b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype) n = InterpolatorParam(name=name + '_n', value=0.0) speed = InterpolatorParam(name=name + '_speed', value=0.1) output = a.variable * (1 - n.variable) + b.variable * n.variable @@ -162,6 +166,17 @@ class InterpolatorParam: val = 0.0 self.assign(val) +class InterpolatorVariable: + def __init__(self, variable): + self.scalar = False + self.variable = variable + + def assign(self): + pass + + def randomize(self): + pass + # -------------------------- # Interpolator graph # -------------------------- @@ -173,9 +188,9 @@ class Interpolator: self.lerp_params = {} def build(self): - InterpolatorParam(name='truncation', value=1.0), - InterpolatorParam(name='num_classes', value=1.0), - abs_zoom = InterpolatorParam(name='abs_zoom', value=0.0), + InterpolatorParam(name='truncation', value=1.0) + InterpolatorParam(name='num_classes', value=1.0) + abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0) lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise") sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise") lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label") |
