summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-11 17:44:53 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-11 17:44:53 +0100
commiteb4d235b415782105fcb4409a04b5706d22bb5d8 (patch)
treeae8b4339d90d270999ff0702b498df1fdac7283c /inversion
parent4e7c8b0c30cda9cb4e7e991ad9d1313b4f8c9d6f (diff)
image inversion
Diffstat (limited to 'inversion')
-rw-r--r--inversion/image_inversion.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index a4cec7f..2c182c1 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -309,6 +309,7 @@ if params.dataset.endswith('.hdf5'):
sample_fns = in_file['fn'][()]
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
+ print("Batch size: {}".format(BATCH_SIZE))
def sample_images_gen():
for i in range(int(NUM_IMGS / BATCH_SIZE)):
i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
@@ -327,9 +328,9 @@ if params.dataset.endswith('.hdf5'):
if NUM_IMGS % BATCH_SIZE != 0:
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
NUM_IMGS += REMAINDER
- sample_images += sample_images[-REMAINDER]
- sample_labels += sample_labels[-REMAINDER]
- sample_fns += sample_fns[-REMAINDER]
+ sample_images += sample_images[:-REMAINDER]
+ sample_labels += sample_labels[:-REMAINDER]
+ sample_fns += sample_fns[:-REMAINDER]
assert(NUM_IMGS % BATCH_SIZE == 0)
else:
sys.exit('Unknown dataset {}.'.format(params.dataset))
@@ -344,15 +345,16 @@ sess.run(tf.tables_initializer())
# Output file.
out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w')
-out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
- dtype='uint8')
-out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE)
-out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM])
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='float32')
+out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE, dtype='float32')
+out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM], dtype='float32')
out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype())
if COND_GAN:
- out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,), dtype='uint32')
+ out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,), dtype='float32')
out_err = out_file.create_dataset('err', (NUM_IMGS,))
+out_fns[:] = sample_fns
+
# Gradient descent w.r.t. generator's inputs.
it = 0
out_pos = 0