summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-11 03:17:48 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-11 03:17:48 +0100
commita982f0ce441cf2321b7bcae63578db4dea20690d (patch)
treef33aa3d926265e36bc23bdc3998d41d7104f5514
parent3e0c18eebeab329fa700efecf61b66dc60d342f5 (diff)
hdf5 workaround
-rw-r--r--inversion/image_inversion.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 9aa2ca1..1201cf3 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -301,10 +301,10 @@ def small_init(shape=[BATCH_SIZE, Z_DIM]):
# --------------------------
if params.dataset.endswith('.hdf5'):
in_file = h5py.File(params.dataset, 'r')
- sample_images = in_file['xtrain'].value
+ sample_images = in_file['xtrain'][()]
if COND_GAN:
- sample_labels = in_file['ytrain'].value
- sample_fns = in_file['fn'].value
+ sample_labels = in_file['ytrain'][()]
+ sample_fns = in_file['fn'][()]
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
def sample_images_gen():
@@ -324,6 +324,7 @@ if params.dataset.endswith('.hdf5'):
latent_gen = sample_latent_gen()
if NUM_IMGS % BATCH_SIZE != 0:
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
+ NUM_IMGS += REMAINDER
sample_images += sample_images[-REMAINDER]
sample_labels += sample_labels[-REMAINDER]
sample_fns += sample_fns[-REMAINDER]