summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2019-12-18 23:52:30 +0100
committerjules@lens <julescarbon@gmail.com>2019-12-18 23:52:30 +0100
commit284d899063fb9401414e302e21966d1ac1b7c0ff (patch)
treeac6aaf8948ff883180d4d2867b45abfc45134bc7 /inversion
parent937c7d6431d6a990b959aabb5fb2e65824fcd4c0 (diff)
fix
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 0fdad7c..0794028 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -51,7 +51,7 @@ def sin(opts, key, shape):
noise = lerp(opts, key + '_noise', shape)
scale = InterpolatorParam(name=key + '_scale')
time = opts['global']['time'].variable
- out = tf.sin(time + noise) * scale
+ out = tf.sin(time + noise) * scale.variable
opts[key] = {
'scale': scale,
}
@@ -61,7 +61,7 @@ def lerp(opts, key, shape):
a = InterpolatorParam(name=key + '_a', shape=shape)
b = InterpolatorParam(name=key + '_b', shape=shape)
n = InterpolatorParam(name=key + '_n')
- out = a * (1 - n) + b * n
+ out = a.variable * (1 - n.variable) + b.variable * n.variable
opts[key] = {
'a': a,
'b': b,
@@ -97,10 +97,10 @@ class Interpolator:
lerp_z = lerp(opts, 'latent', [BATCH_SIZE, Z_DIM])
sin_z = sin(opts, 'sin_z', [BATCH_SIZE, Z_DIM])
lerp_label = lerp(opts, 'label', [BATCH_SIZE, N_CLASS])
- opts['threshold'] = InterpolatorParam('threshold', value=1.0)
+ opts['truncation'] = InterpolatorParam('truncation', value=1.0)
gen_in = {}
- gen_in['threshold'] = opts['threshold'].variable
+ gen_in['truncation'] = opts['truncation'].variable
gen_in['z'] = lerp_z + sin_z
gen_in['y'] = lerp_label
gen_img = generator(gen_in, signature=gen_signature)
@@ -126,14 +126,14 @@ class Interpolator:
def set_value(self, key, value):
self.opts[key].assign(value).eval(session=sess)
- def on_step(i):
+ def on_step(self, i):
gen_time = time.time()
self.opts['global']['time'].assign(gen_time).eval(session=sess)
gen_images = sess.run(self.gen_img)
print("Generation time: {:.1f}s".format(time.time() - gen_time))
return gen_images
- def run(cmd, payload):
+ def run(self, cmd, payload):
# do things like create a new B and interpolate to it
pass