diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-15 21:55:16 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-15 21:55:16 +0100 |
| commit | 0e996eed57590c294f1995b379fb116dee400a1e (patch) | |
| tree | aab150038453985aec3a120fc0bebe34f235bf2b | |
| parent | fbf7e6f3968beb7df30f20e99f494f03aafbb939 (diff) | |
params...
| -rw-r--r-- | cli/app/search/search_dense.py | 12 |
1 files changed, 6 insertions, 6 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py index 6bd73ce..5aee392 100644 --- a/cli/app/search/search_dense.py +++ b/cli/app/search/search_dense.py @@ -327,9 +327,6 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op sample_labels = in_file['ytrain'][()] sample_fns = in_file['fn'][()] NUM_IMGS = sample_images.shape[0] # number of images to be inverted. - if NUM_IMGS < BATCH_SIZE: - BATCH_SIZE = NUM_IMGS - SAMPLE_SIZE = NUM_IMGS print("Number of images: {}".format(NUM_IMGS)) print("Batch size: {}".format(BATCH_SIZE)) def sample_images_gen(): @@ -343,9 +340,10 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE yield sample_latents[i_1:i_2] latent_gen = sample_latent_gen() - if NUM_IMGS % BATCH_SIZE != 0: - REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE) - NUM_IMGS += REMAINDER + TOTAL_IMGS = NUM_IMGS + while TOTAL_IMGS % BATCH_SIZE != 0: + REMAINDER = 1 # BATCH_SIZE - (NUM_IMGS % BATCH_SIZE) + TOTAL_IMGS += REMAINDER sample_images = np.append(sample_images, sample_images[-REMAINDER:,...], axis=0) sample_labels = np.append(sample_labels, sample_labels[-REMAINDER:,...], axis=0) sample_latents = np.append(sample_latents, sample_latents[-REMAINDER:,...], axis=0) @@ -445,6 +443,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op # write encoding, latent to pkl file for i in range(BATCH_SIZE): out_i = out_pos + i + if out_i >= NUM_IMGS: + continue sample_fn, ext = os.path.splitext(sample_fns[out_i]) image = Image.fromarray(images[i]) fp = BytesIO() |
