author    Jules Laplace <julescarbon@gmail.com>  2019-12-10 23:59:19 +0100
committer Jules Laplace <julescarbon@gmail.com>  2019-12-10 23:59:19 +0100
commit    f702091a6f54698b03f9eba2702f023dc6358fbd (patch)
tree      b07c88e2085bb47ef29cb60db34e39289f6b33f8 /cli/app/commands/biggan
parent    1cdbf220659d847fdd3855a62f9cba080347271f (diff)
steps
Diffstat (limited to 'cli/app/commands/biggan')
-rw-r--r--  cli/app/commands/biggan/search.py  15
1 file changed, 9 insertions, 6 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index 47d91f7..9ee440e 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -25,15 +25,17 @@ from app.search.vector import truncated_z_sample, truncated_z_single, create_lab
@click.command('')
@click.option('-i', '--input', 'opt_fp_in', required=True,
help='Path to input image')
-@click.option('-s', '--dims', 'opt_dims', default=128, type=int,
+@click.option('-d', '--dims', 'opt_dims', default=128, type=int,
help='Dimensions of BigGAN network (128, 256, 512)')
+@click.option('-s', '--steps', 'opt_steps', default=500, type=int,
+ help='Number of optimization iterations')
@click.option('-v', '--video', 'opt_video', is_flag=True,
help='Export a video for each dataset')
@click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)),
help='Tag this dataset')
# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True)
@click.pass_context
-def cli(ctx, opt_fp_in, opt_dims, opt_video, opt_tag):
+def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_video, opt_tag):
"""
Search for an image (class vector) in BigGAN using gradient descent
"""
@@ -68,12 +70,12 @@ def cli(ctx, opt_fp_in, opt_dims, opt_video, opt_tag):
out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32')
out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
for index, path in enumerate(paths):
- out_labels[index] = os.path.basename(path)
- fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, index)
+ out_fns[index] = os.path.basename(path)
+ fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, opt_steps, index)
if opt_video:
export_video(fp_frames)
-def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, index):
+def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, opt_steps, index):
"""
Find the closest latent and class vectors for an image. Store the class vector in an HDF5.
"""
@@ -134,7 +136,7 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
try:
- for i in range(500):
+ for i in range(opt_steps):
feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
grad_z, grad_y = sess.run([dc_dz, dc_dy], feed_dict=feed_dict)
z -= grad_z * lr_z
@@ -163,6 +165,7 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
print('lr: {}, iter: {}, grad_z: {}, grad_y: {}'.format(lr_z, i, np.std(grad_z), np.std(grad_y)))
except KeyboardInterrupt:
pass
+ print(y.shape)
out_labels[index] = y
return fp_frames
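
A hypothetical smoke test of the new flag via click's CliRunner; the import path is inferred from the file location, the fixture image is made up, and the heavy BigGAN/TF dependencies mean this only sketches the invocation shape rather than something expected to pass as-is.

    # Hypothetical invocation sketch; module path and input image are assumptions.
    from click.testing import CliRunner
    from app.commands.biggan.search import cli

    runner = CliRunner()
    result = runner.invoke(cli, ['-i', 'input.jpg', '-d', '256', '-s', '1000', '-v'])
    print(result.exit_code)
    print(result.output)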