summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-15 21:55:16 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-15 21:55:16 +0100
commit0e996eed57590c294f1995b379fb116dee400a1e (patch)
treeaab150038453985aec3a120fc0bebe34f235bf2b
parentfbf7e6f3968beb7df30f20e99f494f03aafbb939 (diff)
params...
-rw-r--r--cli/app/search/search_dense.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index 6bd73ce..5aee392 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -327,9 +327,6 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
sample_labels = in_file['ytrain'][()]
sample_fns = in_file['fn'][()]
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
- if NUM_IMGS < BATCH_SIZE:
- BATCH_SIZE = NUM_IMGS
- SAMPLE_SIZE = NUM_IMGS
print("Number of images: {}".format(NUM_IMGS))
print("Batch size: {}".format(BATCH_SIZE))
def sample_images_gen():
@@ -343,9 +340,10 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
yield sample_latents[i_1:i_2]
latent_gen = sample_latent_gen()
- if NUM_IMGS % BATCH_SIZE != 0:
- REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
- NUM_IMGS += REMAINDER
+ TOTAL_IMGS = NUM_IMGS
+ while TOTAL_IMGS % BATCH_SIZE != 0:
+ REMAINDER = 1 # BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
+ TOTAL_IMGS += REMAINDER
sample_images = np.append(sample_images, sample_images[-REMAINDER:,...], axis=0)
sample_labels = np.append(sample_labels, sample_labels[-REMAINDER:,...], axis=0)
sample_latents = np.append(sample_latents, sample_latents[-REMAINDER:,...], axis=0)
@@ -445,6 +443,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# write encoding, latent to pkl file
for i in range(BATCH_SIZE):
out_i = out_pos + i
+ if out_i >= NUM_IMGS:
+ continue
sample_fn, ext = os.path.splitext(sample_fns[out_i])
image = Image.fromarray(images[i])
fp = BytesIO()