summaryrefslogtreecommitdiff
path: root/inversion/image_inversion.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-11 02:35:44 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-11 02:35:44 +0100
commite1b64603c7b555d9af431478937dbef60c4b1d99 (patch)
treef75410f19aac3cd5981f371346d545d73994cebd /inversion/image_inversion.py
parent339f9ae16c26cfe598d0c019a98ef58885f756bc (diff)
store fns in output
Diffstat (limited to 'inversion/image_inversion.py')
-rw-r--r--inversion/image_inversion.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 38c5261..b044190 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -304,6 +304,7 @@ if params.dataset.endswith('.hdf5'):
sample_images = in_file['xtrain']
if COND_GAN:
sample_labels = in_file['ytrain']
+ sample_fns = in_file['fn']
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
def sample_images_gen():
@@ -325,6 +326,7 @@ if params.dataset.endswith('.hdf5'):
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
sample_images = sample_images.append(sample_images[-REMAINDER])
sample_labels = sample_labels.append(sample_labels[-REMAINDER])
+ sample_fns = sample_fns.append(sample_fns[-REMAINDER])
assert(NUM_IMGS % BATCH_SIZE == 0)
else:
sys.exit('Unknown dataset {}.'.format(params.dataset))
@@ -343,6 +345,7 @@ out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
dtype='uint8')
out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE)
out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM])
+out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype())
if COND_GAN:
out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
out_err = out_file.create_dataset('err', (NUM_IMGS,))