summaryrefslogtreecommitdiff
path: root/cli/app/search/live.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/app/search/live.py')
-rw-r--r--cli/app/search/live.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
index b52ec03..63799a1 100644
--- a/cli/app/search/live.py
+++ b/cli/app/search/live.py
@@ -220,6 +220,7 @@ class Interpolator:
gen_layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
encoding_latent = tf.get_default_graph().get_tensor_by_name(gen_layer_name)
+ print(encoding_latent.get_shape())
encoding_shape = [1,] + encoding_latent.get_shape().as_list()[1:]
print(encoding_shape)
encoding_shape_placeholder = tf.constant(np.zeros(encoding_shape, dtype=np.float32))
@@ -272,20 +273,21 @@ class Interpolator:
def set_encoding(self, opt):
next_id = opt['id']
data = load_pickle(os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(next_id)))
+ encoding = np.expand_dims(data['encoding'], axis=0)
encoding_stored = self.lerp_params['encoding_stored']
encoding_mix = self.lerp_params['encoding_mix']
# if we're showing an encoding already, lerp to the next one
if encoding_mix.n.value > 0:
- encoding_stored.switch(target_value=data['encoding'])
+ encoding_stored.switch(target_value=encoding)
# otherwise (we're showing the latent)...
else:
# jump to the stored encoding, then switch
if encoding_stored.n.value < 0.5:
encoding_stored.n.value = 0
- encoding_stored.a.assign(data['encoding'])
+ encoding_stored.a.assign(encoding)
else:
encoding_stored.n.value = 1
- encoding_stored.b.assign(data['encoding'])
+ encoding_stored.b.assign(encoding)
encoding_mix.switch()
def on_step(self, i, dt, sess):