diff options
| -rw-r--r-- | cli/app/commands/biggan/search.py | 6 | ||||
| -rw-r--r-- | inversion/image_inversion.py | 3 |
2 files changed, 8 insertions, 1 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py index d2b5900..9e43458 100644 --- a/cli/app/commands/biggan/search.py +++ b/cli/app/commands/biggan/search.py @@ -29,13 +29,15 @@ from app.search.vector import truncated_z_sample, truncated_z_single, create_lab help='Dimensions of BigGAN network (128, 256, 512)') @click.option('-s', '--steps', 'opt_steps', default=500, type=int, help='Number of optimization iterations') +@click.option('-l', '--limit', 'opt_limit', default=1000, type=int, + help='Limit the number of images to process') @click.option('-v', '--video', 'opt_video', is_flag=True, help='Export a video for each dataset') @click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)), help='Tag this dataset') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context -def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_video, opt_tag): +def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ Search for an image (class vector) in BigGAN using gradient descent """ @@ -70,6 +72,8 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_video, opt_tag): out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32') out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype()) for index, path in enumerate(paths): + if index == opt_limit: + break out_fns[index] = os.path.basename(path) fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, opt_steps, index) if opt_video: diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py index 38c5261..b044190 100644 --- a/inversion/image_inversion.py +++ b/inversion/image_inversion.py @@ -304,6 +304,7 @@ if params.dataset.endswith('.hdf5'): sample_images = in_file['xtrain'] if COND_GAN: sample_labels = in_file['ytrain'] + sample_fns = in_file['fn'] NUM_IMGS = sample_images.shape[0] # number of images to be inverted. print("Number of images: {}".format(NUM_IMGS)) def sample_images_gen(): @@ -325,6 +326,7 @@ if params.dataset.endswith('.hdf5'): REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE) sample_images = sample_images.append(sample_images[-REMAINDER]) sample_labels = sample_labels.append(sample_labels[-REMAINDER]) + sample_fns = sample_fns.append(sample_fns[-REMAINDER]) assert(NUM_IMGS % BATCH_SIZE == 0) else: sys.exit('Unknown dataset {}.'.format(params.dataset)) @@ -343,6 +345,7 @@ out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='uint8') out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE) out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM]) +out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype()) if COND_GAN: out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32') out_err = out_file.create_dataset('err', (NUM_IMGS,)) |
