summaryrefslogtreecommitdiff
path: root/inversion/image_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/image_inversion.py')
-rw-r--r--inversion/image_inversion.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 38c5261..b044190 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -304,6 +304,7 @@ if params.dataset.endswith('.hdf5'):
sample_images = in_file['xtrain']
if COND_GAN:
sample_labels = in_file['ytrain']
+ sample_fns = in_file['fn']
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
def sample_images_gen():
@@ -325,6 +326,7 @@ if params.dataset.endswith('.hdf5'):
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
sample_images = sample_images.append(sample_images[-REMAINDER])
sample_labels = sample_labels.append(sample_labels[-REMAINDER])
+ sample_fns = sample_fns.append(sample_fns[-REMAINDER])
assert(NUM_IMGS % BATCH_SIZE == 0)
else:
sys.exit('Unknown dataset {}.'.format(params.dataset))
@@ -343,6 +345,7 @@ out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
dtype='uint8')
out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE)
out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM])
+out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype())
if COND_GAN:
out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
out_err = out_file.create_dataset('err', (NUM_IMGS,))