summaryrefslogtreecommitdiff
path: root/cli/app/search/search_km.py
blob: bdffbe43641f8ed117380a3c4b8bede6edfbb11e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import cStringIO
import os

import numpy as np
import PIL.Image
from scipy.stats import truncnorm
import tensorflow as tf
import tensorflow_hub as hub
import cv2

# TF-Hub handle of the generator to load. The commented-out handles select the
# larger variants of the same model.
module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN

# Reset the global default graph before building the model into it.
tf.reset_default_graph()
module = hub.Module(module_path)

# One placeholder per module input, with the dtype and shape the module
# reports, keyed (and named) by the input's name.
# `.items()` is used instead of the Python-2-only `.iteritems()`; the resulting
# dict is the same and the expression is valid on Python 2 and Python 3.
inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
          for k, v in module.get_input_info_dict().items()}
# Output tensor of the module when fed from the placeholders above.
output = module(inputs)

# Convenience handles for the three inputs this script feeds.
input_z = inputs['z']
input_y = inputs['y']
input_trunc = inputs['truncation']

# Second dimension of the `z` and `y` placeholders, respectively.
dim_z = input_z.shape.as_list()[1]
vocab_size = input_y.shape.as_list()[1]

# Initialise all graph variables in a fresh session. The session is left open
# because the rest of the script keeps running ops through it.
initializer = tf.global_variables_initializer()
sess = tf.Session()
sess.run(initializer)

y = 259 # pomeranian
n_samples = 9
truncation = 0.5

# Disabled alternative: take the target from an uploaded image instead of
# sampling it from the generator. Kept for reference only.
# phi_target = imread(uploaded.keys()[0])
# phi_target = imconvert_float32(phi_target)
# phi_target = np.expand_dims(phi_target, 0)
# phi_target = phi_target[:128,:128]
# phi_target = np.repeat(phi_target, n_samples, axis=0)

# The same class label for every sample in the batch.
# NOTE(review): `one_hot` is not defined in this file; presumably it returns an
# (n_samples, vocab_size) one-hot array -- confirm against its definition.
label = one_hot([y] * n_samples, vocab_size)

# use z from manifold
# The target is always produced by the generator itself. This used to sit
# behind `if uploaded is not None:`, but `uploaded` is never defined here and
# the upload path above is disabled, so the guard could only end in a
# NameError (either on `uploaded` or, when false, on `phi_target` below).
# A single latent sample is repeated n_samples times along axis 0 so every
# entry of the batch shares one target.
# NOTE(review): the third argument of `truncated_z_sample` is presumably a
# seed -- confirm against its definition.
z_target = np.repeat(truncated_z_sample(1, truncation, 0), n_samples, axis=0)
feed_dict = {input_z: z_target, input_y: label, input_trunc: truncation}
phi_target = sess.run(output, feed_dict=feed_dict)

target_im = imgrid(imconvert_uint8(phi_target), cols=3)

# Sum of squared differences between the generator output and the fixed target,
# and its gradient with respect to the latent input `z`.
cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
dc_dz, = tf.gradients(cost, [input_z])

# Gradient-descent search for latents whose generated images match phi_target.
lr = 0.0001

# Starting latents for the search, and the images they generate (kept for the
# side-by-side comparison shown during the search).
z_guess = np.asarray(truncated_z_sample(n_samples, truncation/2, 1))
feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
phi_impostor = sess.run(output, feed_dict=feed_dict)
impostor_im = imgrid(imconvert_uint8(phi_impostor), cols=3)
comparison = None

# Every iteration writes a frame into 'frames/'; make sure the directory is
# there first so the writes cannot fail (or be dropped) just because it is
# missing. The explicit isdir check keeps this valid on Python 2, where
# os.makedirs has no exist_ok argument.
if not os.path.isdir('frames'):
  os.makedirs('frames')

try:
  for i in range(1000):
    # Gradient of the cost with respect to z at the current guess, followed by
    # an in-place descent step.
    feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
    grad = dc_dz.eval(session=sess, feed_dict=feed_dict)
    z_guess -= grad * lr

    # decay/attenuate learning rate to 0.05 of the original over 1000 frames
    lr *= 0.997

    # Re-draw every latent component that has reached or left the interval
    # (-2*truncation, +2*truncation) from a standard normal.
    # NOTE(review): the fresh draw is not bounded to that interval, so a
    # component may be re-drawn again on a later iteration -- confirm this is
    # intended.
    indices = np.logical_or(z_guess <= -2*truncation, z_guess >= +2*truncation)
    z_guess[indices] = np.random.randn(np.count_nonzero(indices))

    # Render the updated guess and save it as this iteration's frame.
    feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
    phi_guess = sess.run(output, feed_dict=feed_dict)
    guess_im = imgrid(imconvert_uint8(phi_guess), cols=3)

    imwrite('frames/{:06d}.png'.format(i), guess_im)

    # display the progress every 10 frames
    if i % 10 == 0:
      # Starting images, current guess and target, side by side.
      comparison = imgrid(np.asarray([impostor_im, guess_im, target_im]), cols=3, pad=10)

      # clear_output(wait=True)
      print('lr: {}, iter: {}, grad_std: {}'.format(lr, i, np.std(grad)))
      imshow(comparison, format='jpeg')
except KeyboardInterrupt:
  # Ctrl-C is the intended way to stop the search early; frames written so far
  # are kept.
  pass