author    Jules Laplace <julescarbon@gmail.com>    2020-01-06 17:54:57 +0100
committer Jules Laplace <julescarbon@gmail.com>    2020-01-06 17:54:57 +0100
commit    9560db910b876ba5a249f3262bd4e05aa3fa2c2e (patch)
tree      f4558cef3f23646a1b887c77f3646bd4d27a0a09 /cli/app/commands
parent    1084fad3e5fc2a2d70276fbe8cba5e6dfea10dff (diff)
class search works
Diffstat (limited to 'cli/app/commands')
-rw-r--r--  cli/app/commands/biggan/search_class.py | 43
1 file changed, 21 insertions, 22 deletions
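
The hunks below construct the TF Hub BigGAN module once in the CLI entry point and add a fallback from the 'generator' signature to 'default'. As a standalone way to inspect which signatures a given hub module actually exposes, here is a minimal sketch (not part of this commit; it assumes the TF 1.x graph-mode setup the command itself relies on):

    import tensorflow as tf
    import tensorflow_hub as hub

    # Hypothetical standalone check; the module URL is the one used in the diff.
    # hub.Module needs v1 graph behavior, so disable TF2 eager mode if present.
    tf.compat.v1.disable_v2_behavior()
    generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
    names = list(generator.get_signature_names())
    gen_signature = 'generator' if 'generator' in names else 'default'
    print('available signatures:', names, '-> using', gen_signature)
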
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index cbf39b2..b681e1f 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -31,7 +31,7 @@ from app.search.vector import truncated_z_sample, truncated_z_single, create_lab
help='Path to input image')
@click.option('-d', '--dims', 'opt_dims', default=512, type=int,
help='Dimensions of BigGAN network (128, 256, 512)')
-@click.option('-s', '--steps', 'opt_steps', default=2000, type=int,
+@click.option('-s', '--steps', 'opt_steps', default=1000, type=int,
help='Number of optimization iterations')
@click.option('-l', '--limit', 'opt_limit', default=1000, type=int,
help='Limit the number of images to process')
@@ -48,6 +48,8 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
sess = tf.compat.v1.Session()
+ generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+
if os.path.isdir(opt_fp_in):
paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
glob(os.path.join(opt_fp_in, '*.jpeg')) + \
@@ -62,21 +64,23 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32')
- out_latent = out_file.create_dataset('ztrain', (len(paths), 128,), dtype='float32')
+ out_latent = out_file.create_dataset('latent', (len(paths), 128,), dtype='float32')
out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
for index, path in enumerate(paths):
if index == opt_limit:
break
out_fns[index] = os.path.basename(path)
- fp_frames = find_nearest_vector(sess, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
+ fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
if opt_video:
export_video(fp_frames)
-def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
+def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
"""
Find the closest latent and class vectors for an image. Store the class vector in an HDF5.
"""
- generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+ gen_signature = 'generator'
+ if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
# inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
# for k, v in generator.get_input_info_dict().items()}
@@ -87,7 +91,6 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
vocab_size = 1000
img_size = 512
num_channels = 3
- save_step = 20
z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
y_initial = create_labels(batch_size, vocab_size, 10)
@@ -95,16 +98,17 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
z_lr = 0.001
y_lr = 0.001
- input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32)
- input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32)
- input_z_sigmoid = tf.compat.v1.sigmoid(input_z) * 2.0 - 1.0
- input_y_sigmoid = tf.compat.v1.sigmoid(input_y)
+ input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
+ input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
input_trunc = tf.compat.v1.constant(1.0)
output = generator({
- 'z': input_z_sigmoid,
- 'y': input_y_sigmoid,
+ 'z': input_z,
+ 'y': input_y,
'truncation': input_trunc,
- })
+ }, signature=gen_signature)
+
+ layer_name = 'module_apply_' + gen_signature + '/Generator_2/G_Z/Reshape:0'
+ gen_encoding = tf.compat.v1.get_default_graph().get_tensor_by_name(layer_name)
target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
@@ -120,7 +124,6 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
phi_target_for_inversion = resize_and_crop_image(target_im, 512)
b = np.dsplit(phi_target_for_inversion, 3)
phi_target_for_inversion = np.stack(b).reshape((3, 512, 512))
- out_images[index] = phi_target_for_inversion
# create phi target for the latent / label pass
phi_target = resize_and_crop_image(target_im, opt_dims)
@@ -135,25 +138,21 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
target: phi_target,
}
- # feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
- phi_start = sess.run(output)
- start_im = imgrid(imconvert_uint8(phi_start), cols=1)
- imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
-
try:
print("Preparing to iterate...")
for i in range(opt_steps):
curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict)
- if i % save_step == 0:
+ if i % 20 == 0:
phi_guess = sess.run(output)
guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
- imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / save_step))), guess_im)
+ imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / 20))), guess_im)
print('iter: {}, loss: {}'.format(i, curr_loss))
except KeyboardInterrupt:
pass
- z_guess, y_guess = sess.run([input_z_sigmoid, input_y_sigmoid])
+ z_guess, y_guess = sess.run([input_z, input_y])
+ out_images[index] = phi_target_for_inversion
out_labels[index] = y_guess
out_latent[index] = z_guess
return fp_frames
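
For reference, the datasets created above can be read back with h5py along these lines; this is only an illustrative sketch, and the file path is hypothetical (the command writes dataset.hdf5 into its fp_inverses output directory):

    import h5py

    # Keys and shapes follow the datasets created in search_class.py:
    #   xtrain: (N, 3, 512, 512) target images, ytrain: (N, 1000) class vectors,
    #   latent: (N, 128) z vectors, fn: source image filenames.
    with h5py.File('inverses/dataset.hdf5', 'r') as f:
        images = f['xtrain'][:]
        labels = f['ytrain'][:]
        latents = f['latent'][:]
        filenames = [fn.decode() if isinstance(fn, bytes) else fn for fn in f['fn'][:]]
    print(len(filenames), images.shape, labels.shape, latents.shape)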