diff options
Diffstat (limited to 'cli/app/search/search_km.py')
| -rw-r--r-- | cli/app/search/search_km.py | 86 |
1 files changed, 86 insertions, 0 deletions
"""Gradient-descent search over the BigGAN latent vector (search_km.py).

The script loads a BigGAN generator from TF Hub, renders a target batch,
then repeatedly nudges a random latent guess ``z_guess`` along the negative
gradient of the summed squared pixel error between the generated batch and
the target. One PNG frame is written per iteration.

NOTE(review): ``uploaded`` and the helpers ``one_hot``,
``truncated_z_sample``, ``imgrid``, ``imconvert_uint8``, ``imwrite`` and
``imshow`` are referenced below but are not defined or imported anywhere in
this file; they must be supplied before the script can run.
"""
import cStringIO  # NOTE(review): Python 2-only module, unused below.
import numpy as np
import PIL.Image
from scipy.stats import truncnorm
import tensorflow as tf
import tensorflow_hub as hub
import cv2

module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN

# Build the generator graph with one placeholder per declared module input.
tf.reset_default_graph()
module = hub.Module(module_path)
# .items() (not the Python 2-only .iteritems()) so this works on 2 and 3.
inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
          for k, v in module.get_input_info_dict().items()}
output = module(inputs)

input_z = inputs['z']
input_y = inputs['y']
input_trunc = inputs['truncation']

# Second dimension of the placeholders: latent size and number of classes.
dim_z = input_z.shape.as_list()[1]
vocab_size = input_y.shape.as_list()[1]

initializer = tf.global_variables_initializer()
sess = tf.Session()
sess.run(initializer)

y = 259  # pomeranian
n_samples = 9
truncation = 0.5

# phi_target = imread(uploaded.keys()[0])
# phi_target = imconvert_float32(phi_target)
# phi_target = np.expand_dims(phi_target, 0)
# phi_target = phi_target[:128,:128]
# phi_target = np.repeat(phi_target, n_samples, axis=0)

label = one_hot([y] * n_samples, vocab_size)

# use z from manifold
# NOTE(review): `uploaded` is never defined in this file, and `phi_target`
# is only bound inside this branch although it is used unconditionally
# below. The extent of the branch body is assumed (source indentation was
# lost) -- confirm the intended condition.
if uploaded is not None:
    # A single sampled latent repeated across the batch, rendered as target.
    z_target = np.repeat(truncated_z_sample(1, truncation, 0), n_samples, axis=0)
    feed_dict = {input_z: z_target, input_y: label, input_trunc: truncation}
    phi_target = sess.run(output, feed_dict=feed_dict)

target_im = imgrid(imconvert_uint8(phi_target), cols=3)

# Summed squared error against the target, and its gradient w.r.t. z.
cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
dc_dz, = tf.gradients(cost, [input_z])

lr = 0.0001
z_guess = np.asarray(truncated_z_sample(n_samples, truncation/2, 1))
feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
phi_impostor = sess.run(output, feed_dict=feed_dict)
impostor_im = imgrid(imconvert_uint8(phi_impostor), cols=3)
comparison = None

try:
    for i in range(1000):
        # One gradient-descent step on the latent guess (updated in place).
        feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
        grad = dc_dz.eval(session=sess, feed_dict=feed_dict)
        z_guess -= grad * lr

        # decay/attenuate learning rate to 0.05 of the original over 1000 frames
        lr *= 0.997

        # Re-draw every component that reached or passed +/- 2*truncation.
        # NOTE(review): the replacements come from an untruncated standard
        # normal, so they may themselves fall outside the bound and be
        # re-drawn on a later iteration -- confirm this is intended.
        indices = np.logical_or(z_guess <= -2*truncation, z_guess >= +2*truncation)
        z_guess[indices] = np.random.randn(np.count_nonzero(indices))

        # Render the current guess and save it as this iteration's frame.
        feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
        phi_guess = sess.run(output, feed_dict=feed_dict)
        guess_im = imgrid(imconvert_uint8(phi_guess), cols=3)

        imwrite('frames/{:06d}.png'.format(i), guess_im)

        # display the progress every 10 frames
        # NOTE(review): print/imshow are assumed to belong to this branch
        # (source indentation was lost) -- confirm.
        if i % 10 == 0:
            comparison = imgrid(np.asarray([impostor_im, guess_im, target_im]), cols=3, pad=10)

            # clear_output(wait=True)
            print('lr: {}, iter: {}, grad_std: {}'.format(lr, i, np.std(grad)))
            imshow(comparison, format='jpeg')
except KeyboardInterrupt:
    # Ctrl-C ends the search early; frames written so far are kept.
    pass
