import click from app.utils import click_utils from app.settings import app_cfg import os from os.path import join import time import numpy as np import random from subprocess import call import cv2 as cv from PIL import Image from glob import glob import tensorflow as tf import tensorflow_hub as hub import shutil import h5py from app.search.json import save_params_latent, save_params_dense from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \ imread, imwrite, imgrid, resize_and_crop_image from app.search.vector import truncated_z_sample, truncated_z_single, create_labels @click.command('') @click.option('-i', '--input', 'opt_fp_in', required=True, help='Path to input image') @click.option('-d', '--dims', 'opt_dims', default=128, type=int, help='Dimensions of BigGAN network (128, 256, 512)') @click.option('-s', '--steps', 'opt_steps', default=500, type=int, help='Number of optimization iterations') @click.option('-l', '--limit', 'opt_limit', default=1000, type=int, help='Limit the number of images to process') @click.option('-v', '--video', 'opt_video', is_flag=True, help='Export a video for each dataset') @click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)), help='Tag this dataset') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ Search for an image (class vector) in BigGAN using gradient descent """ generator = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2') inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k) for k, v in generator.get_input_info_dict().items()} input_z = inputs['z'] input_y = inputs['y'] input_trunc = inputs['truncation'] output = generator(inputs) vocab_size = input_y.shape.as_list()[1] sess = tf.compat.v1.Session() sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.tables_initializer()) if os.path.isdir(opt_fp_in): paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \ glob(os.path.join(opt_fp_in, '*.jpeg')) + \ glob(os.path.join(opt_fp_in, '*.png')) else: paths = [opt_fp_in] fp_inverses = os.path.join(app_cfg.DIR_INVERSES, opt_tag) os.makedirs(fp_inverses, exist_ok=True) save_params_latent(fp_inverses, opt_tag) save_params_dense(fp_inverses, opt_tag) out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w') out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32') out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32') out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype()) for index, path in enumerate(paths): if index == opt_limit: break out_fns[index] = os.path.basename(path) fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, opt_steps, index) if opt_video: export_video(fp_frames) def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, opt_steps, index): """ Find the closest latent and class vectors for an image. Store the class vector in an HDF5. """ z_dim = input_z.shape.as_list()[1] vocab_size = input_y.shape.as_list()[1] # scalar truncation value in [0.02, 1.0] batch_size = 25 truncation = 1.0 z = truncated_z_sample(batch_size, z_dim, truncation/2) num_classes = 1 y = create_labels(batch_size, vocab_size, num_classes) if opt_fp_in: print("Processing {}".format(opt_fp_in)) fn = os.path.basename(opt_fp_in) fbase, ext = os.path.splitext(fn) fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000)) os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True) target_im = imread(opt_fp_in) # crop image to 512 and save for later processing phi_target_for_inversion = resize_and_crop_image(target_im, 512) b = np.dsplit(phi_target_for_inversion, 3) phi_target_for_inversion = np.stack(b).reshape((3, 512, 512)) out_images[index] = phi_target_for_inversion # crop image to 128 to find vectors phi_target = resize_and_crop_image(target_im, opt_dims) phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) else: print("Processing random vector") fp_frames = "frames_{}".format(int(time.time() * 1000)) os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True) z_target = np.repeat(truncated_z_single(z_dim, truncation), batch_size, axis=0) y_target = np.repeat(create_labels(1, vocab_size, 1), batch_size, axis=0) feed_dict = {input_z: z_target, input_y: y_target, input_trunc: truncation} phi_target = sess.run(output, feed_dict=feed_dict) target_im = imgrid(imconvert_uint8(phi_target), cols=5) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_target.png'), target_im) #dy_dx = g.gradient(y, x) cost = tf.reduce_sum(tf.pow(output - phi_target, 2)) dc_dz, dc_dy, = tf.gradients(cost, [input_z, input_y]) #dc_dy, = tf.gradients(cost, [input_y]) lr_z = 0.0001 lr_y = 0.000001 #z = truncated_z_sample(batch_size, z_dim, truncation/2) feed_dict = {input_z: z, input_y: y, input_trunc: truncation} phi_start = sess.run(output, feed_dict=feed_dict) start_im = imgrid(imconvert_uint8(phi_start), cols=5) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) try: for i in range(opt_steps): feed_dict = {input_z: z, input_y: y, input_trunc: truncation} grad_z, grad_y = sess.run([dc_dz, dc_dy], feed_dict=feed_dict) z -= grad_z * lr_z y -= grad_y * lr_y lr_z *= 0.997 lr_y *= 0.999 if i % 30 == 0: # lr_y *= 1.002 y = np.clip(y, 0, 1) # for j in range(batch_size): # y[j] /= y[j].sum() if i > 200 and i % 100 == 0: n = i / opt_steps mean = np.mean(y, axis=0) y = y * (1 - n) + mean * n indices = np.logical_or(z <= -2*truncation, z >= +2*truncation) z[indices] = np.random.randn(np.count_nonzero(indices)) feed_dict = {input_z: z, input_y: y, input_trunc: truncation} phi_guess = sess.run(output, feed_dict=feed_dict) guess_im = imgrid(imconvert_uint8(phi_guess), cols=5) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im) if i % 20 == 0: print('lr: {}, iter: {}, grad_z: {}, grad_y: {}'.format(lr_z, i, np.std(grad_z), np.std(grad_y))) except KeyboardInterrupt: pass print(y.shape) out_labels[index] = np.mean(y, axis=0) return fp_frames def export_video(fp_frames): print("Exporting video...") cmd = [ '/home/lens/bin/ffmpeg', '-y', # '-v', 'quiet', '-r', '30', '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'), '-pix_fmt', 'yuv420p', join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4') ] # print(' '.join(cmd)) call(cmd) shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames))