| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 17:54:57 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 17:54:57 +0100 |
| commit | 9560db910b876ba5a249f3262bd4e05aa3fa2c2e | |
| tree | f4558cef3f23646a1b887c77f3646bd4d27a0a09 | |
| parent | 1084fad3e5fc2a2d70276fbe8cba5e6dfea10dff | |
class search works
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 43 |
1 file changed, 21 insertions(+), 22 deletions(-)
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index cbf39b2..b681e1f 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -31,7 +31,7 @@ from app.search.vector import truncated_z_sample, truncated_z_single, create_lab
   help='Path to input image')
 @click.option('-d', '--dims', 'opt_dims', default=512, type=int,
   help='Dimensions of BigGAN network (128, 256, 512)')
-@click.option('-s', '--steps', 'opt_steps', default=2000, type=int,
+@click.option('-s', '--steps', 'opt_steps', default=1000, type=int,
   help='Number of optimization iterations')
 @click.option('-l', '--limit', 'opt_limit', default=1000, type=int,
   help='Limit the number of images to process')
@@ -48,6 +48,8 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):

   sess = tf.compat.v1.Session()

+  generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+
   if os.path.isdir(opt_fp_in):
     paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
       glob(os.path.join(opt_fp_in, '*.jpeg')) + \
@@ -62,21 +64,23 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
   out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
   out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
   out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32')
-  out_latent = out_file.create_dataset('ztrain', (len(paths), 128,), dtype='float32')
+  out_latent = out_file.create_dataset('latent', (len(paths), 128,), dtype='float32')
   out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())

   for index, path in enumerate(paths):
     if index == opt_limit:
       break
     out_fns[index] = os.path.basename(path)
-    fp_frames = find_nearest_vector(sess, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
+    fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
     if opt_video:
       export_video(fp_frames)

-def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
+def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
   """
   Find the closest latent and class vectors for an image.
   Store the class vector in an HDF5.
   """
-  generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+  gen_signature = 'generator'
+  if 'generator' not in generator.get_signature_names():
+    gen_signature = 'default'

   # inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
   #           for k, v in generator.get_input_info_dict().items()}
@@ -87,7 +91,6 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
   vocab_size = 1000
   img_size = 512
   num_channels = 3
-  save_step = 20

   z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
   y_initial = create_labels(batch_size, vocab_size, 10)
@@ -95,16 +98,17 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
   z_lr = 0.001
   y_lr = 0.001

-  input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32)
-  input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32)
-  input_z_sigmoid = tf.compat.v1.sigmoid(input_z) * 2.0 - 1.0
-  input_y_sigmoid = tf.compat.v1.sigmoid(input_y)
+  input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
+  input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
   input_trunc = tf.compat.v1.constant(1.0)

   output = generator({
-    'z': input_z_sigmoid,
-    'y': input_y_sigmoid,
+    'z': input_z,
+    'y': input_y,
     'truncation': input_trunc,
-  })
+  }, signature=gen_signature)
+
+  layer_name = 'module_apply_' + gen_signature + '/' + "Generator_2/G_Z/Reshape:0"
+  gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)

   target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
@@ -120,7 +124,6 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
   phi_target_for_inversion = resize_and_crop_image(target_im, 512)
   b = np.dsplit(phi_target_for_inversion, 3)
   phi_target_for_inversion = np.stack(b).reshape((3, 512, 512))
-  out_images[index] = phi_target_for_inversion

   # create phi target for the latent / label pass
   phi_target = resize_and_crop_image(target_im, opt_dims)
@@ -135,25 +138,21 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
     target: phi_target,
   }

-  # feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
-  phi_start = sess.run(output)
-  start_im = imgrid(imconvert_uint8(phi_start), cols=1)
-  imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
-
   try:
     print("Preparing to iterate...")
     for i in range(opt_steps):
       curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict)
-      if i % save_step == 0:
+      if i % 20 == 0:
         phi_guess = sess.run(output)
         guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
-        imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / save_step))), guess_im)
+        imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / 20))), guess_im)
         print('iter: {}, loss: {}'.format(i, curr_loss))
   except KeyboardInterrupt:
     pass

-  z_guess, y_guess = sess.run([input_z_sigmoid, input_y_sigmoid])
+  z_guess, y_guess = sess.run([input_z, input_y])
+  out_images[index] = phi_target_for_inversion
   out_labels[index] = y_guess
   out_latent[index] = z_guess
   return fp_frames
