summaryrefslogtreecommitdiff
path: root/cli/app/search/search_km.py
diff options
context:
space:
mode:
Diffstat (limited to 'cli/app/search/search_km.py')
-rw-r--r--cli/app/search/search_km.py86
1 files changed, 0 insertions, 86 deletions
diff --git a/cli/app/search/search_km.py b/cli/app/search/search_km.py
deleted file mode 100644
index bdffbe4..0000000
--- a/cli/app/search/search_km.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import cStringIO
-import numpy as np
-import PIL.Image
-from scipy.stats import truncnorm
-import tensorflow as tf
-import tensorflow_hub as hub
-import cv2
-
-module_path = 'https://tfhub.dev/deepmind/biggan-128/2' # 128x128 BigGAN
-# module_path = 'https://tfhub.dev/deepmind/biggan-256/2' # 256x256 BigGAN
-# module_path = 'https://tfhub.dev/deepmind/biggan-512/2' # 512x512 BigGAN
-
-tf.reset_default_graph()
-module = hub.Module(module_path)
-inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
- for k, v in module.get_input_info_dict().iteritems()}
-output = module(inputs)
-
-input_z = inputs['z']
-input_y = inputs['y']
-input_trunc = inputs['truncation']
-
-dim_z = input_z.shape.as_list()[1]
-vocab_size = input_y.shape.as_list()[1]
-
-initializer = tf.global_variables_initializer()
-sess = tf.Session()
-sess.run(initializer)
-
-y = 259 # pomeranian
-n_samples = 9
-truncation = 0.5
-
-# phi_target = imread(uploaded.keys()[0])
-# phi_target = imconvert_float32(phi_target)
-# phi_target = np.expand_dims(phi_target, 0)
-# phi_target = phi_target[:128,:128]
-# phi_target = np.repeat(phi_target, n_samples, axis=0)
-
-label = one_hot([y] * n_samples, vocab_size)
-
-# use z from manifold
-if uploaded is not None:
- z_target = np.repeat(truncated_z_sample(1, truncation, 0), n_samples, axis=0)
- feed_dict = {input_z: z_target, input_y: label, input_trunc: truncation}
- phi_target = sess.run(output, feed_dict=feed_dict)
-
-target_im = imgrid(imconvert_uint8(phi_target), cols=3)
-cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
-dc_dz, = tf.gradients(cost, [input_z])
-
-lr = 0.0001
-z_guess = np.asarray(truncated_z_sample(n_samples, truncation/2, 1))
-feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
-phi_impostor = sess.run(output, feed_dict=feed_dict)
-impostor_im = imgrid(imconvert_uint8(phi_impostor), cols=3)
-comparison = None
-
-try:
- for i in range(1000):
- feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
- grad = dc_dz.eval(session=sess, feed_dict=feed_dict)
- z_guess -= grad * lr
-
- # decay/attenuate learning rate to 0.05 of the original over 1000 frames
- lr *= 0.997
-
- indices = np.logical_or(z_guess <= -2*truncation, z_guess >= +2*truncation)
- z_guess[indices] = np.random.randn(np.count_nonzero(indices))
-
- feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
- phi_guess = sess.run(output, feed_dict=feed_dict)
- guess_im = imgrid(imconvert_uint8(phi_guess), cols=3)
-
- imwrite('frames/{:06d}.png'.format(i), guess_im)
-
- # display the progress every 10 frames
- if i % 10 == 0:
- comparison = imgrid(np.asarray([impostor_im, guess_im, target_im]), cols=3, pad=10)
-
- # clear_output(wait=True)
- print('lr: {}, iter: {}, grad_std: {}'.format(lr, i, np.std(grad)))
- imshow(comparison, format='jpeg')
-except KeyboardInterrupt:
- pass
-