summaryrefslogtreecommitdiff
path: root/cli/app/search/search_dense.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/app/search/search_dense.py')
-rw-r--r--  cli/app/search/search_dense.py  28
1 file changed, 18 insertions, 10 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index 0086db5..392fa70 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -100,6 +100,8 @@ def find_dense_embedding_for_images(params):
# --------------------------
# Load Graph.
# --------------------------
+ tf.reset_default_graph()
+
generator = hub.Module(str(params.generator_path))
gen_signature = 'generator'
@@ -419,6 +421,9 @@ def find_dense_embedding_for_images(params):
out_fns[:] = sample_fns[:NUM_IMGS_TO_PROCESS]
+ vector_dir = os.path.join(app_cfg.INVERSES_DIR, "vectors")
+ os.makedirs(vector_dir, exist_ok=True)
+
# Gradient descent w.r.t. generator's inputs.
it = 0
out_pos = 0
@@ -485,19 +490,22 @@ def find_dense_embedding_for_images(params):
# write encoding, latent to pkl file
for i in range(BATCH_SIZE):
out_i = out_pos + i
- fn, ext = os.path.splitext(sample_fns[out_i])
- fp_out_pkl = os.path.join(app_cfg.INVERSES_DIR, fn + ".pkl")
- out_data = {
- 'id': fn,
- 'latent': out_lat[out_i],
- 'encoding': out_enc[out_i],
- 'label': out_labels[out_i],
- }
- write_pickle(out_data, fp_out_pkl)
+ sample_fn, ext = os.path.splitext(sample_fns[out_i])
image = Image.fromarray(images[i])
fp = BytesIO()
image.save(fp, format='png')
- upload_bytes_to_cortex(params.folder_id, fn, fp, 'image/png')
+ data = upload_bytes_to_cortex(params.folder_id, sample_fn + "-inverse.png", fp, "image/png")
+ if data is not None:
+ file_id = data['id']
+ fp_out_pkl = os.path.join(vector_dir, "file_{}.pkl".format(file_id))
+ out_data = {
+ 'id': file_id,
+ 'sample_fn': sample_fn,
+ 'label': out_labels[out_i],
+ 'latent': out_lat[out_i],
+ 'encoding': out_enc[out_i],
+ }
+ write_pickle(out_data, fp_out_pkl)
out_pos += BATCH_SIZE
if params.max_batches > 0 and (out_pos / BATCH_SIZE) >= params.max_batches: