summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/app/commands/biggan/search.py6
-rw-r--r--inversion/image_inversion.py3
2 files changed, 8 insertions, 1 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index d2b5900..9e43458 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -29,13 +29,15 @@ from app.search.vector import truncated_z_sample, truncated_z_single, create_lab
help='Dimensions of BigGAN network (128, 256, 512)')
@click.option('-s', '--steps', 'opt_steps', default=500, type=int,
help='Number of optimization iterations')
+@click.option('-l', '--limit', 'opt_limit', default=1000, type=int,
+ help='Limit the number of images to process')
@click.option('-v', '--video', 'opt_video', is_flag=True,
help='Export a video for each dataset')
@click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)),
help='Tag this dataset')
# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True)
@click.pass_context
-def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_video, opt_tag):
+def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
"""
Search for an image (class vector) in BigGAN using gradient descent
"""
@@ -70,6 +72,8 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_video, opt_tag):
out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32')
out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
for index, path in enumerate(paths):
+ if index == opt_limit:
+ break
out_fns[index] = os.path.basename(path)
fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, opt_steps, index)
if opt_video:
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 38c5261..b044190 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -304,6 +304,7 @@ if params.dataset.endswith('.hdf5'):
sample_images = in_file['xtrain']
if COND_GAN:
sample_labels = in_file['ytrain']
+ sample_fns = in_file['fn']
NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
print("Number of images: {}".format(NUM_IMGS))
def sample_images_gen():
@@ -325,6 +326,7 @@ if params.dataset.endswith('.hdf5'):
REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
sample_images = sample_images.append(sample_images[-REMAINDER])
sample_labels = sample_labels.append(sample_labels[-REMAINDER])
+ sample_fns = sample_fns.append(sample_fns[-REMAINDER])
assert(NUM_IMGS % BATCH_SIZE == 0)
else:
sys.exit('Unknown dataset {}.'.format(params.dataset))
@@ -343,6 +345,7 @@ out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
dtype='uint8')
out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE)
out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM])
+out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype())
if COND_GAN:
out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
out_err = out_file.create_dataset('err', (NUM_IMGS,))