Diffstat (limited to 'cli/app/search/search_class.py')
-rw-r--r--  cli/app/search/search_class.py  153
1 file changed, 153 insertions(+), 0 deletions(-)
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
new file mode 100644
index 0000000..eb9ff42
--- /dev/null
+++ b/cli/app/search/search_class.py
@@ -0,0 +1,153 @@
+from app.settings import app_cfg
+
+import os
+from os.path import join
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+import time
+import shutil
+from subprocess import call
+from glob import glob
+
+import numpy as np
+import h5py
+import tensorflow as tf
+import tensorflow_hub as hub
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+# the graph-mode API used below (placeholders, Session) requires eager
+# execution to be off when running under TF2
+tf.compat.v1.disable_eager_execution()
+
+from app.search.json import save_params_latent, save_params_dense
+from app.search.image import imconvert_uint8, imread, imwrite, imgrid, \
+    resize_and_crop_image
+from app.search.vector import truncated_z_sample, create_labels_uniform
+
+def find_nearest_vector_for_images(opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
+ sess = tf.compat.v1.Session()
+
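+    # BigGAN 512x512 generator from TF Hub; it takes a 128-d latent z, a
+    # 1000-d class vector y, and a truncation scalar, and renders RGB images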
+ generator = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+
+ if os.path.isdir(opt_fp_in):
+ paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
+ glob(os.path.join(opt_fp_in, '*.jpeg')) + \
+ glob(os.path.join(opt_fp_in, '*.png'))
+ else:
+ paths = [opt_fp_in]
+
+ fp_inverses = os.path.join(app_cfg.DIR_INVERSES, opt_tag)
+ os.makedirs(fp_inverses, exist_ok=True)
+ save_params_latent(fp_inverses, opt_tag)
+ save_params_dense(fp_inverses, opt_tag)
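+    # one HDF5 row per input image: CHW target pixels, the recovered 1000-d
+    # class vector, the recovered 128-d latent, and the source filename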
+ out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
+ out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
+ out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32')
+ out_latent = out_file.create_dataset('latent', (len(paths), 128,), dtype='float32')
+ out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
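+    # invert each image in turn; opt_limit caps how many images are processed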
+ for index, path in enumerate(paths):
+ if index == opt_limit:
+ break
+ out_fns[index] = os.path.basename(path)
+ fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index)
+ if opt_video:
+ export_video(fp_frames)
+
+def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index):
+ """
+    Find the closest latent and class vectors for the target image and store
+    the image, class vector, and latent in the given HDF5 datasets.
+ """
+ batch_size = 1
+ truncation = 1.0
+
+ z_dim = 128
+ vocab_size = 1000
+ img_size = 512
+ num_channels = 3
+
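+    # start from a random truncated-normal latent and a uniform class vector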
+ z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
+ y_initial = create_labels_uniform(batch_size, vocab_size)
+
+ z_lr = 0.001
+ y_lr = 0.001
+
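+    # optimize z and y directly as graph variables; the constraints clip z to
+    # the truncated-normal support and keep y within a valid [0, 1] range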
+ input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
+ input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
+ input_trunc = tf.compat.v1.constant(1.0)
+ output = generator({
+ 'z': input_z,
+ 'y': input_y,
+ 'truncation': input_trunc,
+ })
+
+ target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
+
+ # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask)
+ loss = tf.compat.v1.losses.mean_squared_error(target, output)
+
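+    # separate Adam optimizers let z and y use independent learning rates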
+    train_step_z = tf.compat.v1.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOptimizerZ')
+    train_step_y = tf.compat.v1.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOptimizerY')
+
+ target_im, fp_frames = load_target_image(opt_fp_in)
+
+ # crop image and convert to format for next script
+ phi_target_for_inversion = resize_and_crop_image(target_im, 512)
+    phi_target_for_inversion = np.transpose(phi_target_for_inversion, (2, 0, 1))  # HWC -> CHW
+
+ # create phi target for the latent / label pass
+ phi_target = resize_and_crop_image(target_im, opt_dims)
+ phi_target = np.expand_dims(phi_target, 0)
+ phi_target = np.repeat(phi_target, batch_size, axis=0)
+
+ # IMPORTANT: initialize variables before running the session
+ sess.run(tf.compat.v1.global_variables_initializer())
+ sess.run(tf.compat.v1.tables_initializer())
+
+ feed_dict = {
+ target: phi_target,
+ }
+
+ try:
+        print("Optimizing latent and class vectors for {} steps...".format(opt_steps))
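+        # each iteration runs both Adam steps, jointly refining z and y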
+ for i in range(opt_steps):
+ curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict)
+
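+            # every 20 steps, render the current guess and save it as a frame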
+ if i % 20 == 0:
+ phi_guess = sess.run(output)
+ guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
+                imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i // 20)), guess_im)
+ print('iter: {}, loss: {}'.format(i, curr_loss))
+ except KeyboardInterrupt:
+ pass
+
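+    # read back the optimized vectors and store them with the target image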
+ z_guess, y_guess = sess.run([input_z, input_y])
+ out_images[index] = phi_target_for_inversion
+ out_labels[index] = y_guess
+ out_latent[index] = z_guess
+ return fp_frames
+
+def export_video(fp_frames):
+ print("Exporting video...")
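+    # stitch the saved frames into an mp4 at 30 fps, then remove the frame
+    # directory; the ffmpeg binary path is hardcoded for the deploy machine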
+ cmd = [
+ '/home/lens/bin/ffmpeg',
+ '-y', # '-v', 'quiet',
+ '-r', '30',
+ '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'),
+ '-pix_fmt', 'yuv420p',
+ join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4')
+ ]
+ # print(' '.join(cmd))
+ call(cmd)
+ shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames))
+
+def load_target_image(opt_fp_in):
+ print("Loading {}".format(opt_fp_in))
+ fn = os.path.basename(opt_fp_in)
+ fbase, ext = os.path.splitext(fn)
+ fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000))
+ fp_frames_fullpath = join(app_cfg.DIR_OUTPUTS, fp_frames)
+ print("Output to {}".format(fp_frames_fullpath))
+ os.makedirs(fp_frames_fullpath, exist_ok=True)
+ target_im = imread(opt_fp_in)
+ return target_im, fp_frames