diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-11 17:44:53 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-11 17:44:53 +0100 |
| commit | eb4d235b415782105fcb4409a04b5706d22bb5d8 (patch) | |
| tree | ae8b4339d90d270999ff0702b498df1fdac7283c /inversion/image_inversion.py | |
| parent | 4e7c8b0c30cda9cb4e7e991ad9d1313b4f8c9d6f (diff) | |
image inversion
Diffstat (limited to 'inversion/image_inversion.py')
| -rw-r--r-- | inversion/image_inversion.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py index a4cec7f..2c182c1 100644 --- a/inversion/image_inversion.py +++ b/inversion/image_inversion.py @@ -309,6 +309,7 @@ if params.dataset.endswith('.hdf5'): sample_fns = in_file['fn'][()] NUM_IMGS = sample_images.shape[0] # number of images to be inverted. print("Number of images: {}".format(NUM_IMGS)) + print("Batch size: {}".format(BATCH_SIZE)) def sample_images_gen(): for i in range(int(NUM_IMGS / BATCH_SIZE)): i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE @@ -327,9 +328,9 @@ if params.dataset.endswith('.hdf5'): if NUM_IMGS % BATCH_SIZE != 0: REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE) NUM_IMGS += REMAINDER - sample_images += sample_images[-REMAINDER] - sample_labels += sample_labels[-REMAINDER] - sample_fns += sample_fns[-REMAINDER] + sample_images += sample_images[:-REMAINDER] + sample_labels += sample_labels[:-REMAINDER] + sample_fns += sample_fns[:-REMAINDER] assert(NUM_IMGS % BATCH_SIZE == 0) else: sys.exit('Unknown dataset {}.'.format(params.dataset)) @@ -344,15 +345,16 @@ sess.run(tf.tables_initializer()) # Output file. out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w') -out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, - dtype='uint8') -out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE) -out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM]) +out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='float32') +out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE, dtype='float32') +out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM], dtype='float32') out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype()) if COND_GAN: - out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,), dtype='uint32') + out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,), dtype='float32') out_err = out_file.create_dataset('err', (NUM_IMGS,)) +out_fns[:] = sample_fns + # Gradient descent w.r.t. generator's inputs. it = 0 out_pos = 0 |
