diff options
| -rw-r--r-- | inversion/image_inversion.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py index 9aa2ca1..1201cf3 100644 --- a/inversion/image_inversion.py +++ b/inversion/image_inversion.py @@ -301,10 +301,10 @@ def small_init(shape=[BATCH_SIZE, Z_DIM]): # -------------------------- if params.dataset.endswith('.hdf5'): in_file = h5py.File(params.dataset, 'r') - sample_images = in_file['xtrain'].value + sample_images = in_file['xtrain'][()] if COND_GAN: - sample_labels = in_file['ytrain'].value - sample_fns = in_file['fn'].value + sample_labels = in_file['ytrain'][()] + sample_fns = in_file['fn'][()] NUM_IMGS = sample_images.shape[0] # number of images to be inverted. print("Number of images: {}".format(NUM_IMGS)) def sample_images_gen(): @@ -324,6 +324,7 @@ if params.dataset.endswith('.hdf5'): latent_gen = sample_latent_gen() if NUM_IMGS % BATCH_SIZE != 0: REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE) + NUM_IMGS += REMAINDER sample_images += sample_images[-REMAINDER] sample_labels += sample_labels[-REMAINDER] sample_fns += sample_fns[-REMAINDER] |
