summaryrefslogtreecommitdiff
path: root/cli/app/search/search_dense.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-10 19:45:20 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-10 19:45:20 +0100
commit735b26c5b88ae1de1eac134f6d9a626ded33eeac (patch)
tree1caecff9a1bee5d2bf570a0f494998a22b91a5ba /cli/app/search/search_dense.py
parentf462fe520aca4fd15127fd6d0b27e342e2f23a14 (diff)
load latents and labels, share interpolation amount
Diffstat (limited to 'cli/app/search/search_dense.py')
-rw-r--r--cli/app/search/search_dense.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index c7cf078..616ba6a 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -280,7 +280,7 @@ def find_dense_embedding_for_images(params):
params.inv_it / params.decay_n, 0.1, staircase=True)
else:
lrate = tf.constant(params.lr)
- trained_params = [encoding] if params.fixed_z else [latent, encoding]
+ trained_params = [latent, encoding]
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
global_step=inv_step)
@@ -406,12 +406,12 @@ def find_dense_embedding_for_images(params):
it += 1
# Save images that are ready.
- latent_batch, enc_batch, rec_err_batch = sess.run([latent, encoding, img_rec_err])
- out_lat[out_pos:out_pos+BATCH_SIZE] = latent_batch
- out_enc[out_pos:out_pos+BATCH_SIZE] = enc_batch
+ label_trained, latent_trained, enc_trained, rec_err_trained = sess.run([label, latent, encoding, img_rec_err])
+ out_lat[out_pos:out_pos+BATCH_SIZE] = latent_trained
+ out_enc[out_pos:out_pos+BATCH_SIZE] = enc_trained
out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
- out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
- out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
+ out_labels[out_pos:out_pos+BATCH_SIZE] = label_trained
+ out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_trained
gen_images = sess.run(gen_img_orig)
images = vs.data2img(gen_images)