summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-10 16:21:26 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-10 16:21:26 +0100
commitf462fe520aca4fd15127fd6d0b27e342e2f23a14 (patch)
treef0194a31c4bf6f278b7e155cdb19d3c38bc9faed
parent258dd9a3634239a6bab64620b408f023c46fdc4a (diff)
graph magic
-rw-r--r--cli/app/search/live.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
index 80aafc6..21a9c5a 100644
--- a/cli/app/search/live.py
+++ b/cli/app/search/live.py
@@ -225,7 +225,7 @@ class Interpolator:
encoding_shape_placeholder = tf.constant(np.zeros(encoding_shape_specific, dtype=np.float32))
encoding_stored = LerpParam('encoding_stored', shape=encoding_shape_specific, datatype="encoding")
- encoding_mix = LerpParam('encoding_mix', a_in=encoding_stored.output, b_in=encoding_shape_placeholder, shape=encoding_shape_specific, datatype="encoding")
+ encoding_mix = LerpParam('encoding_mix', a_in=encoding_shape_placeholder, b_in=encoding_stored.output, shape=encoding_shape_specific, datatype="encoding")
# use the placeholder to redirect parts of the graph.
# - computed encoding goes into the encoding_mix
# - encoding mix output goes into the main biggan graph
@@ -273,22 +273,26 @@ class Interpolator:
def set_encoding(self, opt):
next_id = opt['id']
data = load_pickle(os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(next_id)))
- encoding = np.expand_dims(data['encoding'], axis=0)
+ new_encoding = np.expand_dims(data['encoding'], axis=0)
+ new_label = np.expand_dims(data['label'], axis=0)
encoding_stored = self.lerp_params['encoding_stored']
encoding_mix = self.lerp_params['encoding_mix']
+ label = self.lerp_params['label']
# if we're showing an encoding already, lerp to the next one
if encoding_mix.n.value > 0:
- encoding_stored.switch(target_value=encoding)
+ encoding_stored.switch(target_value=new_encoding)
+ label.switch(target_value=new_label)
# otherwise (we're showing the latent)...
else:
# jump to the stored encoding, then switch
if encoding_stored.n.value < 0.5:
encoding_stored.n.value = 0
- encoding_stored.a.assign(encoding)
+ encoding_stored.a.assign(new_encoding)
else:
encoding_stored.n.value = 1
- encoding_stored.b.assign(encoding)
+ encoding_stored.b.assign(new_encoding)
encoding_mix.switch()
+ label.switch(target_value=new_label)
def on_step(self, i, dt, sess):
for param in self.sin_params.values():