authorJules Laplace <julescarbon@gmail.com>2020-01-06 14:23:39 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-06 14:23:39 +0100
commita5098c8dca8ad18117a105e1fe47205de9b5f5ec (patch)
treef1632fa5f3ea722521a201f1f11e5b119dd21ec6
parent242eb958ddc4e08759349687b90f1555f0f4e23a (diff)
search class
-rw-r--r--  cli/app/commands/biggan/search_class.py  169
1 files changed, 169 insertions, 0 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
new file mode 100644
index 0000000..6339e8b
--- /dev/null
+++ b/cli/app/commands/biggan/search_class.py
@@ -0,0 +1,169 @@
+import click
+
+from app.utils import click_utils
+from app.settings import app_cfg
+
+import os
+from os.path import join
+import time
+import numpy as np
+import random
+from subprocess import call
+import cv2 as cv
+from PIL import Image
+from glob import glob
+import tensorflow as tf
+import tensorflow_hub as hub
+import shutil
+import h5py
+
+from app.search.json import save_params_latent, save_params_dense
+from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \
+ imread, imwrite, imgrid, resize_and_crop_image
+from app.search.vector import truncated_z_sample, truncated_z_single, create_labels
+
+@click.command('')
+@click.option('-i', '--input', 'opt_fp_in', required=True,
+ help='Path to input image')
+@click.option('-d', '--dims', 'opt_dims', default=512, type=int,
+ help='Dimensions of BigGAN network (128, 256, 512)')
+@click.option('-s', '--steps', 'opt_steps', default=500, type=int,
+ help='Number of optimization iterations')
+@click.option('-l', '--limit', 'opt_limit', default=1000, type=int,
+ help='Limit the number of images to process')
+@click.option('-v', '--video', 'opt_video', is_flag=True,
+ help='Export a video for each dataset')
+@click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)),
+ help='Tag this dataset')
+# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True)
+@click.pass_context
+def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
+ """
+ Search for an image (class vector) in BigGAN using gradient descent
+ """
+  vocab_size = 1000  # ImageNet class count; input_y does not exist yet in this scope
+
+  sess = tf.compat.v1.Session()
+  # variables are created per image inside find_nearest_vector,
+  # so the initializers are run there, after the graph is built
+
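+  # collect the input images: a single file, or every jpg/jpeg/png in a directory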
+ if os.path.isdir(opt_fp_in):
+ paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
+ glob(os.path.join(opt_fp_in, '*.jpeg')) + \
+ glob(os.path.join(opt_fp_in, '*.png'))
+ else:
+ paths = [opt_fp_in]
+
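+  # set up the output folder and the HDF5 datasets, one row per input image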
+ fp_inverses = os.path.join(app_cfg.DIR_INVERSES, opt_tag)
+ os.makedirs(fp_inverses, exist_ok=True)
+ save_params_latent(fp_inverses, opt_tag)
+ save_params_dense(fp_inverses, opt_tag)
+ out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
+ out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
+ out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32')
+  # latent size depends on the chosen network: the tfhub biggan modules use
+  # z vectors of 120 (128px), 140 (256px), and 128 (512px) dimensions
+  z_dims = {128: 120, 256: 140, 512: 128}
+  out_latent = out_file.create_dataset('ztrain', (len(paths), z_dims[opt_dims],), dtype='float32')
+ out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
+ for index, path in enumerate(paths):
+ if index == opt_limit:
+ break
+ out_fns[index] = os.path.basename(path)
+    fp_frames = find_nearest_vector(sess, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
+ if opt_video:
+ export_video(fp_frames)
+
+def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
+ """
+  Find the closest latent and class vectors for an image by gradient descent,
+  and store them, with the resized image, in the open HDF5 datasets.
+ """
+  # pick the hub module matching the requested resolution (128, 256 or 512)
+  generator = hub.Module('https://tfhub.dev/deepmind/biggan-{}/2'.format(opt_dims))
+
+ # inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
+ # for k, v in generator.get_input_info_dict().items()}
+ batch_size = 1
+ truncation = 1.0
+
+  # read the latent dimensionality from the module rather than hard-coding it
+  z_dim = generator.get_input_info_dict()['z'].get_shape().as_list()[1]
+ vocab_size = 1000
+ num_channels = 3
+
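+  # random starting points for the latent vector z and the class vector y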
+ z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
+ y_initial = create_labels(batch_size, vocab_size, 10)
+
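+  # separate learning rates: the class vector is updated far more gently than the latent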
+ z_lr = 0.001
+ y_lr = 0.00001
+
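+  # z and y are trainable variables, clipped to the ranges BigGAN expects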
+ input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
+ input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
+ input_trunc = tf.compat.v1.constant(1.0)
+ output = generator({
+ 'z': input_z,
+ 'y': input_y,
+ 'truncation': input_trunc,
+ })
+
+  # placeholder for the target image at the generator's output resolution
+  target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, opt_dims, opt_dims, num_channels))
+
+ # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask)
+ loss = tf.losses.mean_squared_error(target, output)
+
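+  # separate Adam optimizers so z and y each move at their own learning rate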
+  train_step_z = tf.compat.v1.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOptZ')
+  train_step_y = tf.compat.v1.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOptY')
+
+ target_im, fp_frames = load_target_image(opt_fp_in)
+
+ # crop image and convert to format for next script
+ phi_target_for_inversion = resize_and_crop_image(target_im, 512)
+  phi_target_for_inversion = np.transpose(phi_target_for_inversion, (2, 0, 1))  # HWC -> CHW
+ out_images[index] = phi_target_for_inversion
+
+ # create phi target for the latent / label pass
+ phi_target = resize_and_crop_image(target_im, opt_dims)
+ phi_target = np.expand_dims(phi_target, 0)
+ phi_target = np.repeat(phi_target, batch_size, axis=0)
+
+  # the graph is now complete: initialize the module, variables and optimizer slots
+  sess.run(tf.compat.v1.global_variables_initializer())
+  sess.run(tf.compat.v1.tables_initializer())
+
+  phi_start = sess.run(output)
+ start_im = imconvert_uint8(phi_start)
+ imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
+
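+  # optimize, writing one frame per step so the search can be rendered as a video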
+ try:
+    for i in range(opt_steps):
+      _, _, loss_val = sess.run([train_step_z, train_step_y, loss], feed_dict={target: phi_target})
+
+      phi_guess = sess.run(output)
+      guess_im = imconvert_uint8(phi_guess)
+      imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im)
+      if i % 20 == 0:
+        print('iter: {}, z_lr: {}, y_lr: {}, loss: {}'.format(i, z_lr, y_lr, loss_val))
+ except KeyboardInterrupt:
+ pass
+
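+  # keep the final latent and class vectors for this image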
+ z_guess = sess.run(input_z)
+ y_guess = sess.run(input_y)
+ out_labels[index] = y_guess
+ out_latent[index] = z_guess
+ return fp_frames
+
+def export_video(fp_frames):
+ print("Exporting video...")
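+  # stitch the per-step frames into an mp4, then delete the frame folder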
+ cmd = [
+ '/home/lens/bin/ffmpeg',
+ '-y', # '-v', 'quiet',
+ '-r', '30',
+ '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'),
+ '-pix_fmt', 'yuv420p',
+ join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4')
+ ]
+ # print(' '.join(cmd))
+ call(cmd)
+ shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames))
+
+def load_target_image(opt_fp_in):
+ print("Processing {}".format(opt_fp_in))
+ fn = os.path.basename(opt_fp_in)
+ fbase, ext = os.path.splitext(fn)
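+  # timestamped frames folder so repeated runs on the same file do not collide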
+ fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000))
+ os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
+ target_im = imread(opt_fp_in)
+ return target_im, fp_frames