summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/image_inversion.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 9aa2ca1..1201cf3 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -301,10 +301,10 @@ def small_init(shape=[BATCH_SIZE, Z_DIM]):
# --------------------------
if params.dataset.endswith('.hdf5'):
in_file = h5py.File(params.dataset, 'r')
- sample_images = in_file['xtrain'].value
+ sample_images = in_file['xtrain'][()]
if COND_GAN:
- sample_labels = in_file['ytrain'].value
- sample_fns = in_file['fn'].value
+ sample_labels = in_file['ytrain'][()]
+ sample_fns = in_file['fn'][()]
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
def sample_images_gen():
@@ -324,6 +324,7 @@ if params.dataset.endswith('.hdf5'):
latent_gen = sample_latent_gen()
if NUM_IMGS % BATCH_SIZE != 0:
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
+ NUM_IMGS += REMAINDER
sample_images += sample_images[-REMAINDER]
sample_labels += sample_labels[-REMAINDER]
sample_fns += sample_fns[-REMAINDER]