summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/image_inversion.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index b044190..9aa2ca1 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -301,10 +301,10 @@ def small_init(shape=[BATCH_SIZE, Z_DIM]):
# --------------------------
if params.dataset.endswith('.hdf5'):
in_file = h5py.File(params.dataset, 'r')
- sample_images = in_file['xtrain']
+ sample_images = in_file['xtrain'].value
if COND_GAN:
- sample_labels = in_file['ytrain']
- sample_fns = in_file['fn']
+ sample_labels = in_file['ytrain'].value
+ sample_fns = in_file['fn'].value
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
def sample_images_gen():
@@ -324,9 +324,9 @@ if params.dataset.endswith('.hdf5'):
latent_gen = sample_latent_gen()
if NUM_IMGS % BATCH_SIZE != 0:
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
- sample_images = sample_images.append(sample_images[-REMAINDER])
- sample_labels = sample_labels.append(sample_labels[-REMAINDER])
- sample_fns = sample_fns.append(sample_fns[-REMAINDER])
+ sample_images += sample_images[-REMAINDER]
+ sample_labels += sample_labels[-REMAINDER]
+ sample_fns += sample_fns[-REMAINDER]
assert(NUM_IMGS % BATCH_SIZE == 0)
else:
sys.exit('Unknown dataset {}.'.format(params.dataset))