import click from app.utils import click_utils from app.settings import app_cfg import os from os.path import join os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import time import numpy as np import random from subprocess import call import cv2 as cv from PIL import Image from glob import glob import tensorflow as tf import tensorflow_hub as hub import shutil import h5py tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from app.search.json import save_params_latent, save_params_dense from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \ imread, imwrite, imgrid, resize_and_crop_image from app.search.vector import truncated_z_sample, truncated_z_single, \ create_labels, create_labels_uniform @click.command('') @click.option('-i', '--input', 'opt_fp_in', required=True, help='Path to input image') @click.option('-d', '--dims', 'opt_dims', default=512, type=int, help='Dimensions of BigGAN network (128, 256, 512)') @click.option('-s', '--steps', 'opt_steps', default=2000, type=int, help='Number of optimization iterations') @click.option('-l', '--limit', 'opt_limit', default=1000, type=int, help='Limit the number of images to process') @click.option('-v', '--video', 'opt_video', is_flag=True, help='Export a video for each dataset') @click.option('-t', '--tag', 'opt_tag', default='inverse_' + str(int(time.time() * 1000)), help='Tag this dataset') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ Search for an image (class vector) in BigGAN using gradient descent """ sess = tf.compat.v1.Session() generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2') if os.path.isdir(opt_fp_in): paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \ glob(os.path.join(opt_fp_in, '*.jpeg')) + \ glob(os.path.join(opt_fp_in, '*.png')) else: paths = [opt_fp_in] fp_inverses = os.path.join(app_cfg.DIR_INVERSES, opt_tag) os.makedirs(fp_inverses, exist_ok=True) save_params_latent(fp_inverses, opt_tag) save_params_dense(fp_inverses, opt_tag) out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w') out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32') out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32') out_latent = out_file.create_dataset('latent', (len(paths), 128,), dtype='float32') out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype()) for index, path in enumerate(paths): if index == opt_limit: break out_fns[index] = os.path.basename(path) fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index) if opt_video: export_video(fp_frames) def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index): """ Find the closest latent and class vectors for an image. Store the class vector in an HDF5. """ gen_signature = 'generator' if 'generator' not in generator.get_signature_names(): gen_signature = 'default' # inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k) # for k, v in generator.get_input_info_dict().items()} batch_size = 1 truncation = 1.0 z_dim = 128 vocab_size = 1000 img_size = 512 num_channels = 3 z_initial = truncated_z_sample(batch_size, z_dim, truncation/2) y_initial = create_labels_uniform(batch_size, vocab_size) z_lr = 0.001 y_lr = 0.001 input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) input_trunc = tf.compat.v1.constant(1.0) output = generator({ 'z': input_z, 'y': input_y, 'truncation': input_trunc, }, signature=gen_signature) layer_name = 'module_apply_' + gen_signature + '/' + "Generator_2/G_Z/Reshape:0" gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name) target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask) loss = tf.compat.v1.losses.mean_squared_error(target, output) train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ') train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY') target_im, fp_frames = load_target_image(opt_fp_in) # crop image and convert to format for next script phi_target_for_inversion = resize_and_crop_image(target_im, 512) b = np.dsplit(phi_target_for_inversion, 3) phi_target_for_inversion = np.stack(b).reshape((3, 512, 512)) # create phi target for the latent / label pass phi_target = resize_and_crop_image(target_im, opt_dims) phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) # IMPORTANT: initialize variables before running the session sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.tables_initializer()) feed_dict = { target: phi_target, } try: print("Preparing to iterate...") for i in range(opt_steps): curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict) if i % 20 == 0: phi_guess = sess.run(output) guess_im = imgrid(imconvert_uint8(phi_guess), cols=1) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / 20))), guess_im) print('iter: {}, loss: {}'.format(i, curr_loss)) except KeyboardInterrupt: pass z_guess, y_guess = sess.run([input_z, input_y]) out_images[index] = phi_target_for_inversion out_labels[index] = y_guess out_latent[index] = z_guess return fp_frames def export_video(fp_frames): print("Exporting video...") cmd = [ '/home/lens/bin/ffmpeg', '-y', # '-v', 'quiet', '-r', '30', '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'), '-pix_fmt', 'yuv420p', join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4') ] # print(' '.join(cmd)) call(cmd) shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames)) def load_target_image(opt_fp_in): print("Loading {}".format(opt_fp_in)) fn = os.path.basename(opt_fp_in) fbase, ext = os.path.splitext(fn) fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000)) fp_frames_fullpath = join(app_cfg.DIR_OUTPUTS, fp_frames) print("Output to {}".format(fp_frames_fullpath)) os.makedirs(fp_frames_fullpath, exist_ok=True) target_im = imread(opt_fp_in) return target_im, fp_frames