| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 14:23:39 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 14:23:39 +0100 |
| commit | a5098c8dca8ad18117a105e1fe47205de9b5f5ec (patch) | |
| tree | f1632fa5f3ea722521a201f1f11e5b119dd21ec6 | |
| parent | 242eb958ddc4e08759349687b90f1555f0f4e23a (diff) | |
search class
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 169 |
1 file changed, 169 insertions(+), 0 deletions(-)
```diff
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
new file mode 100644
index 0000000..6339e8b
--- /dev/null
+++ b/cli/app/commands/biggan/search_class.py
@@ -0,0 +1,169 @@
+import click
+
+from app.utils import click_utils
+from app.settings import app_cfg
+
+import os
+from os.path import join
+import time
+import numpy as np
+import random
+from subprocess import call
+import cv2 as cv
+from PIL import Image
+from glob import glob
+import tensorflow as tf
+import tensorflow_hub as hub
+import shutil
+import h5py
+
+from app.search.json import save_params_latent, save_params_dense
+from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \
+    imread, imwrite, imgrid, resize_and_crop_image
+from app.search.vector import truncated_z_sample, truncated_z_single, create_labels
+
+
+@click.command('')
+@click.option('-i', '--input', 'opt_fp_in', required=True,
+    help='Path to input image or directory of images')
+@click.option('-d', '--dims', 'opt_dims', default=512, type=int,
+    help='Dimensions of the BigGAN network (128, 256, 512); must match the hub module')
+@click.option('-s', '--steps', 'opt_steps', default=500, type=int,
+    help='Number of optimization iterations')
+@click.option('-l', '--limit', 'opt_limit', default=1000, type=int,
+    help='Limit the number of images to process')
+@click.option('-v', '--video', 'opt_video', is_flag=True,
+    help='Export a video for each image')
+@click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)),
+    help='Tag this dataset')
+# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True)
+@click.pass_context
+def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
+    """
+    Search for an image (class vector) in BigGAN using gradient descent
+    """
+    vocab_size = 1000  # ImageNet class count, matching BigGAN's label vocabulary
+
+    sess = tf.compat.v1.Session()
+
+    if os.path.isdir(opt_fp_in):
+        paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
+            glob(os.path.join(opt_fp_in, '*.jpeg')) + \
+            glob(os.path.join(opt_fp_in, '*.png'))
+    else:
+        paths = [opt_fp_in]
+
+    fp_inverses = os.path.join(app_cfg.DIR_INVERSES, opt_tag)
+    os.makedirs(fp_inverses, exist_ok=True)
+    save_params_latent(fp_inverses, opt_tag)
+    save_params_dense(fp_inverses, opt_tag)
+    out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
+    out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
+    out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32')
+    out_latent = out_file.create_dataset('ztrain', (len(paths), 128,), dtype='float32')
+    out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
+    for index, path in enumerate(paths):
+        if index == opt_limit:
+            break
+        out_fns[index] = os.path.basename(path)
+        fp_frames = find_nearest_vector(sess, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
+        if opt_video:
+            export_video(fp_frames)
+
+
+def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
+    """
+    Find the closest latent and class vectors for an image. Store them in the HDF5 datasets.
+    """
+    # NOTE: re-instantiating the module here adds to the graph on every call
+    generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+
+    # inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
+    #           for k, v in generator.get_input_info_dict().items()}
+    batch_size = 1
+    truncation = 1.0
+
+    z_dim = 128  # BigGAN-512 uses a 128-dimensional latent
+    vocab_size = 1000
+    num_channels = 3
+
+    z_initial = truncated_z_sample(batch_size, z_dim, truncation / 2)
+    y_initial = create_labels(batch_size, vocab_size, 10)
+
+    z_lr = 0.001
+    y_lr = 0.00001
+
+    # Optimize z and y directly, clipped to the ranges BigGAN expects
+    input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
+    input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
+    input_trunc = tf.compat.v1.constant(1.0)
+    output = generator({
+        'z': input_z,
+        'y': input_y,
+        'truncation': input_trunc,
+    })
+
+    target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, opt_dims, opt_dims, num_channels))
+
+    # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask)
+    loss = tf.compat.v1.losses.mean_squared_error(target, output)
+
+    train_step_z = tf.compat.v1.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ')
+    train_step_y = tf.compat.v1.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY')
+
+    # Initialize only after the whole graph (including Adam slot variables) exists
+    sess.run(tf.compat.v1.global_variables_initializer())
+    sess.run(tf.compat.v1.tables_initializer())
+
+    target_im, fp_frames = load_target_image(opt_fp_in)
+
+    # crop the image and convert to the channels-first format used by the next script
+    phi_target_for_inversion = resize_and_crop_image(target_im, 512)
+    b = np.dsplit(phi_target_for_inversion, 3)
+    phi_target_for_inversion = np.stack(b).reshape((3, 512, 512))
+    out_images[index] = phi_target_for_inversion
+
+    # create the phi target for the latent / label pass
+    phi_target = resize_and_crop_image(target_im, opt_dims)
+    phi_target = np.expand_dims(phi_target, 0)
+    phi_target = np.repeat(phi_target, batch_size, axis=0)
+
+    phi_start = sess.run(output)
+    start_im = imconvert_uint8(phi_start)
+    imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
+
+    feed_dict = {target: phi_target}
+    try:
+        for i in range(opt_steps):
+            _, _, loss_val = sess.run([train_step_z, train_step_y, loss], feed_dict=feed_dict)
+
+            phi_guess = sess.run(output)
+            guess_im = imconvert_uint8(phi_guess)
+            imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im)
+            if i % 20 == 0:
+                print('iter: {}, loss: {:.6f}'.format(i, loss_val))
+    except KeyboardInterrupt:
+        pass
+
+    z_guess = sess.run(input_z)
+    y_guess = sess.run(input_y)
+    out_labels[index] = y_guess
+    out_latent[index] = z_guess
+    return fp_frames
+
+
+def export_video(fp_frames):
+    print("Exporting video...")
+    cmd = [
+        '/home/lens/bin/ffmpeg',
+        '-y',  # '-v', 'quiet',
+        '-r', '30',
+        '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'),
+        '-pix_fmt', 'yuv420p',
+        join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4')
+    ]
+    # print(' '.join(cmd))
+    call(cmd)
+    shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames))
+
+
+def load_target_image(opt_fp_in):
+    print("Processing {}".format(opt_fp_in))
+    fn = os.path.basename(opt_fp_in)
+    fbase, ext = os.path.splitext(fn)
+    fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000))
+    os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
+    target_im = imread(opt_fp_in)
+    return target_im, fp_frames
```

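`truncated_z_sample` and `create_labels` are imported from `app.search.vector` and are not part of this diff. A plausible reconstruction is below: `truncated_z_sample` follows DeepMind's BigGAN demo notebook, while `create_labels` is an assumption (soft labels spread evenly over a few random classes, matching the `create_labels(batch_size, vocab_size, 10)` call above).

```python
# Sketch of the helpers imported from app.search.vector (not in this diff).
import numpy as np
from scipy.stats import truncnorm

def truncated_z_sample(batch_size, dim_z, truncation=1.0, seed=None):
    # Sample z from a truncated normal, as in DeepMind's BigGAN demo notebook.
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
    return (truncation * values).astype(np.float32)

def create_labels(batch_size, vocab_size, num_classes):
    # Assumed behavior: equal weight over `num_classes` random classes,
    # giving the class-vector optimization a soft starting point.
    y = np.zeros((batch_size, vocab_size), dtype=np.float32)
    for row in y:
        idx = np.random.choice(vocab_size, num_classes, replace=False)
        row[idx] = 1.0 / num_classes
    return y
```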
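The command writes one `dataset.hdf5` per tag, holding the cropped targets (`xtrain`), recovered class vectors (`ytrain`), latents (`ztrain`), and source filenames (`fn`). Reading it back with h5py is straightforward; the path below is illustrative:

```python
# Inspect a dataset produced by this command (path is illustrative).
import h5py

with h5py.File('inverses/demo/dataset.hdf5', 'r') as f:
    images = f['xtrain']   # (N, 3, 512, 512) float32, channels-first crops
    labels = f['ytrain']   # (N, 1000) optimized class vectors
    latents = f['ztrain']  # (N, 128) optimized latent vectors
    names = f['fn']        # source filenames
    print(names[0], labels[0].argmax(), latents[0][:5])
```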