summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/inversion/live.py b/inversion/live.py
index b5ac0c3..b07913a 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -102,9 +102,9 @@ class SinParam:
self.orbit_time.value += self.orbit_speed.value * dt
class LerpParam:
- def __init__(self, name, shape, datatype="noise"):
- a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
- b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
+ def __init__(self, name, shape, a_in=None, b_in=None, datatype="noise"):
+ a = a_in or InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
+ b = b_in or InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
n = InterpolatorParam(name=name + '_n', value=0.0)
speed = InterpolatorParam(name=name + '_speed', value=0.1)
output = a.variable * (1 - n.variable) + b.variable * n.variable
@@ -175,16 +175,21 @@ class Interpolator:
def build(self):
InterpolatorParam(name='truncation', value=1.0),
InterpolatorParam(name='num_classes', value=1.0),
+ abs_zoom = InterpolatorParam(name='abs_zoom', value=0.0),
lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
+ z_sum = lerp_z.output + sin_z.output
+ z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable
+ z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_mix, shape=[BATCH_SIZE, Z_DIM], datatype="input")
+
# self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise')
# self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label')
gen_in = {}
gen_in['truncation'] = 1.0 # self.opts['truncation'].variable
- gen_in['z'] = lerp_z.output + sin_z.output
+ gen_in['z'] = z_mix
gen_in['y'] = lerp_label.output
self.gen_img = generator(gen_in, signature=gen_signature)