diff options
Diffstat (limited to 'cli/app/commands')
| -rw-r--r-- | cli/app/commands/biggan/search.py | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py index 47d91f7..9ee440e 100644 --- a/cli/app/commands/biggan/search.py +++ b/cli/app/commands/biggan/search.py @@ -25,15 +25,17 @@ from app.search.vector import truncated_z_sample, truncated_z_single, create_lab @click.command('') @click.option('-i', '--input', 'opt_fp_in', required=True, help='Path to input image') -@click.option('-s', '--dims', 'opt_dims', default=128, type=int, +@click.option('-d', '--dims', 'opt_dims', default=128, type=int, help='Dimensions of BigGAN network (128, 256, 512)') +@click.option('-s', '--steps', 'opt_steps', default=500, type=int, + help='Number of optimization iterations') @click.option('-v', '--video', 'opt_video', is_flag=True, help='Export a video for each dataset') @click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)), help='Tag this dataset') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context -def cli(ctx, opt_fp_in, opt_dims, opt_video, opt_tag): +def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_video, opt_tag): """ Search for an image (class vector) in BigGAN using gradient descent """ @@ -68,12 +70,12 @@ def cli(ctx, opt_fp_in, opt_dims, opt_video, opt_tag): out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32') out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype()) for path, index in enumerate(paths): - out_labels[index] = os.path.basename(path) - fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, index) + out_fns[index] = os.path.basename(path) + fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, opt_steps, index) if opt_video: export_video(fp_frames) -def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, index): +def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, opt_steps, index): """ Find the closest latent and class vectors for an image. Store the class vector in an HDF5. """ @@ -134,7 +136,7 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) try: - for i in range(500): + for i in range(opt_steps): feed_dict = {input_z: z, input_y: y, input_trunc: truncation} grad_z, grad_y = sess.run([dc_dz, dc_dy], feed_dict=feed_dict) z -= grad_z * lr_z @@ -163,6 +165,7 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, print('lr: {}, iter: {}, grad_z: {}, grad_y: {}'.format(lr_z, i, np.std(grad_z), np.std(grad_y))) except KeyboardInterrupt: pass + print(y.shape) out_labels[index] = y return fp_frames |
