diff options
Diffstat (limited to 'cli/app/search/search_km.py')
| -rw-r--r-- | cli/app/search/search_km.py | 86 |
1 files changed, 0 insertions, 86 deletions
diff --git a/cli/app/search/search_km.py b/cli/app/search/search_km.py deleted file mode 100644 index bdffbe4..0000000 --- a/cli/app/search/search_km.py +++ /dev/null @@ -1,86 +0,0 @@ -import cStringIO -import numpy as np -import PIL.Image -from scipy.stats import truncnorm -import tensorflow as tf -import tensorflow_hub as hub -import cv2 - -module_path = 'https://tfhub.dev/deepmind/biggan-128/2' # 128x128 BigGAN -# module_path = 'https://tfhub.dev/deepmind/biggan-256/2' # 256x256 BigGAN -# module_path = 'https://tfhub.dev/deepmind/biggan-512/2' # 512x512 BigGAN - -tf.reset_default_graph() -module = hub.Module(module_path) -inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k) - for k, v in module.get_input_info_dict().iteritems()} -output = module(inputs) - -input_z = inputs['z'] -input_y = inputs['y'] -input_trunc = inputs['truncation'] - -dim_z = input_z.shape.as_list()[1] -vocab_size = input_y.shape.as_list()[1] - -initializer = tf.global_variables_initializer() -sess = tf.Session() -sess.run(initializer) - -y = 259 # pomeranian -n_samples = 9 -truncation = 0.5 - -# phi_target = imread(uploaded.keys()[0]) -# phi_target = imconvert_float32(phi_target) -# phi_target = np.expand_dims(phi_target, 0) -# phi_target = phi_target[:128,:128] -# phi_target = np.repeat(phi_target, n_samples, axis=0) - -label = one_hot([y] * n_samples, vocab_size) - -# use z from manifold -if uploaded is not None: - z_target = np.repeat(truncated_z_sample(1, truncation, 0), n_samples, axis=0) - feed_dict = {input_z: z_target, input_y: label, input_trunc: truncation} - phi_target = sess.run(output, feed_dict=feed_dict) - -target_im = imgrid(imconvert_uint8(phi_target), cols=3) -cost = tf.reduce_sum(tf.pow(output - phi_target, 2)) -dc_dz, = tf.gradients(cost, [input_z]) - -lr = 0.0001 -z_guess = np.asarray(truncated_z_sample(n_samples, truncation/2, 1)) -feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation} -phi_impostor = sess.run(output, feed_dict=feed_dict) -impostor_im = imgrid(imconvert_uint8(phi_impostor), cols=3) -comparison = None - -try: - for i in range(1000): - feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation} - grad = dc_dz.eval(session=sess, feed_dict=feed_dict) - z_guess -= grad * lr - - # decay/attenuate learning rate to 0.05 of the original over 1000 frames - lr *= 0.997 - - indices = np.logical_or(z_guess <= -2*truncation, z_guess >= +2*truncation) - z_guess[indices] = np.random.randn(np.count_nonzero(indices)) - - feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation} - phi_guess = sess.run(output, feed_dict=feed_dict) - guess_im = imgrid(imconvert_uint8(phi_guess), cols=3) - - imwrite('frames/{:06d}.png'.format(i), guess_im) - - # display the progress every 10 frames - if i % 10 == 0: - comparison = imgrid(np.asarray([impostor_im, guess_im, target_im]), cols=3, pad=10) - - # clear_output(wait=True) - print('lr: {}, iter: {}, grad_std: {}'.format(lr, i, np.std(grad))) - imshow(comparison, format='jpeg') -except KeyboardInterrupt: - pass - |
