import click from app.utils import click_utils from app.settings import app_cfg import os from os.path import join import time import numpy as np import random from subprocess import call import cv2 as cv from PIL import Image from glob import glob import tensorflow as tf import tensorflow_hub as hub import shutil import h5py tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from app.search.json import save_params_latent, save_params_dense from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \ imread, imwrite, imgrid, resize_and_crop_image from app.search.vector import truncated_z_sample, truncated_z_single, create_labels @click.command('') @click.option('-i', '--input', 'opt_fp_in', required=True, help='Path to input image') @click.option('-d', '--dims', 'opt_dims', default=512, type=int, help='Dimensions of BigGAN network (128, 256, 512)') @click.option('-s', '--steps', 'opt_steps', default=500, type=int, help='Number of optimization iterations') @click.option('-l', '--limit', 'opt_limit', default=1000, type=int, help='Limit the number of images to process') @click.option('-v', '--video', 'opt_video', is_flag=True, help='Export a video for each dataset') @click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)), help='Tag this dataset') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ Search for an image (class vector) in BigGAN using gradient descent """ sess = tf.compat.v1.Session() if os.path.isdir(opt_fp_in): paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \ glob(os.path.join(opt_fp_in, '*.jpeg')) + \ glob(os.path.join(opt_fp_in, '*.png')) else: paths = [opt_fp_in] fp_inverses = os.path.join(app_cfg.DIR_INVERSES, opt_tag) os.makedirs(fp_inverses, exist_ok=True) save_params_latent(fp_inverses, opt_tag) save_params_dense(fp_inverses, opt_tag) out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w') out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32') out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32') out_latent = out_file.create_dataset('ztrain', (len(paths), 128,), dtype='float32') out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype()) for index, path in enumerate(paths): if index == opt_limit: break out_fns[index] = os.path.basename(path) fp_frames = find_nearest_vector(sess, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index) if opt_video: export_video(fp_frames) def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index): """ Find the closest latent and class vectors for an image. Store the class vector in an HDF5. """ generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2') # inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k) # for k, v in generator.get_input_info_dict().items()} batch_size = 1 truncation = 1.0 z_dim = 128 vocab_size = 1000 img_size = 512 num_channels = 3 z_initial = truncated_z_sample(batch_size, z_dim, truncation/2) y_initial = create_labels(batch_size, vocab_size, 10) z_lr = 0.001 y_lr = 0.00001 input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) input_trunc = tf.compat.v1.constant(1.0) output = generator({ 'z': input_z, 'y': input_y, 'truncation': input_trunc, }) target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask) loss = tf.compat.v1.losses.mean_squared_error(target, output) train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ') train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY') target_im, fp_frames = load_target_image(opt_fp_in) # crop image and convert to format for next script phi_target_for_inversion = resize_and_crop_image(target_im, 512) b = np.dsplit(phi_target_for_inversion, 3) phi_target_for_inversion = np.stack(b).reshape((3, 512, 512)) out_images[index] = phi_target_for_inversion # create phi target for the latent / label pass phi_target = resize_and_crop_image(target_im, opt_dims) phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) # IMPORTANT: initialize variables before running the session sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.tables_initializer()) feed_dict = { target: phi_target, } # feed_dict = {input_z: z, input_y: y, input_trunc: truncation} phi_start = sess.run(output, feed_dict=feed_dict) start_im = imconvert_uint8(phi_start) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) try: print("Iteration start") for i in range(opt_steps): curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict) phi_guess = sess.run(output) guess_im = imconvert_uint8(phi_guess) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im) if i % 20 == 0: print('iter: {}, loss: {}'.format(i, curr_loss)) except KeyboardInterrupt: pass z_guess = sess.run(input_z) y_guess = sess.run(input_y) out_labels[index] = y_guess out_latent[index] = z_guess return fp_frames def export_video(fp_frames): print("Exporting video...") cmd = [ '/home/lens/bin/ffmpeg', '-y', # '-v', 'quiet', '-r', '30', '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'), '-pix_fmt', 'yuv420p', join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4') ] # print(' '.join(cmd)) call(cmd) shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames)) def load_target_image(opt_fp_in): print("Loading {}".format(opt_fp_in)) fn = os.path.basename(opt_fp_in) fbase, ext = os.path.splitext(fn) fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000)) fp_frames_fullpath = join(app_cfg.DIR_OUTPUTS, fp_frames) print("Output to {}".format(fp_frames_fullpath)) os.makedirs(fp_frames_fullpath, exist_ok=True) target_im = imread(opt_fp_in) return target_im, fp_frames