"""Search BigGAN's latent space for vectors that reproduce a target image.

A target batch is rendered from one latent vector repeated ``n_samples``
times.  A second batch of latent guesses is then moved by plain gradient
descent on the summed squared pixel difference between the generator output
and the target.  One image grid is written to ``frames/`` per iteration and a
side-by-side comparison is displayed every 10 iterations.

The helpers ``one_hot``, ``truncated_z_sample``, ``imgrid``,
``imconvert_uint8``, ``imwrite`` and ``imshow``, and the variable
``uploaded``, are not defined in this part of the file and must already be
in scope.
"""
import cStringIO
import os
import numpy as np
import PIL.Image
from scipy.stats import truncnorm
import tensorflow as tf
import tensorflow_hub as hub
import cv2

module_path = 'https://tfhub.dev/deepmind/biggan-128/2'  # 128x128 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-256/2'  # 256x256 BigGAN
# module_path = 'https://tfhub.dev/deepmind/biggan-512/2'  # 512x512 BigGAN

# Build the generator graph with one placeholder per module input so that the
# latent vector, class label and truncation can all be fed on every run.
tf.reset_default_graph()
module = hub.Module(module_path)
inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
          for k, v in module.get_input_info_dict().iteritems()}
output = module(inputs)

input_z = inputs['z']
input_y = inputs['y']
input_trunc = inputs['truncation']

dim_z = input_z.shape.as_list()[1]
vocab_size = input_y.shape.as_list()[1]

initializer = tf.global_variables_initializer()
sess = tf.Session()
sess.run(initializer)

y = 259  # pomeranian
n_samples = 9
truncation = 0.5

# phi_target = imread(uploaded.keys()[0])
# phi_target = imconvert_float32(phi_target)
# phi_target = np.expand_dims(phi_target, 0)
# phi_target = phi_target[:128,:128]
# phi_target = np.repeat(phi_target, n_samples, axis=0)

label = one_hot([y] * n_samples, vocab_size)

# use z from manifold
# NOTE(review): `uploaded` is not defined in the visible code, and when it is
# None `phi_target` is never assigned, so the lines below raise NameError.
# The extent of this `if` body was reconstructed from flattened source --
# confirm against the original notebook.
if uploaded is not None:
    # The same latent vector is repeated so every sample shares one target.
    z_target = np.repeat(truncated_z_sample(1, truncation, 0), n_samples, axis=0)
    feed_dict = {input_z: z_target, input_y: label, input_trunc: truncation}
    phi_target = sess.run(output, feed_dict=feed_dict)

target_im = imgrid(imconvert_uint8(phi_target), cols=3)

# Summed squared error against the fixed target, differentiated with respect
# to the latent input only (the generator weights are not updated).
cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
dc_dz, = tf.gradients(cost, [input_z])

lr = 0.0001
z_guess = np.asarray(truncated_z_sample(n_samples, truncation/2, 1))

# Keep the starting guess as an image so it can be shown next to the result.
feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
phi_impostor = sess.run(output, feed_dict=feed_dict)
impostor_im = imgrid(imconvert_uint8(phi_impostor), cols=3)

comparison = None

# Latent entries are kept strictly inside +/- z_limit.
z_limit = 2 * truncation

# Make sure the frame output directory exists before the first write
# (os.makedirs has no exist_ok argument on Python 2).
if not os.path.isdir('frames'):
    os.makedirs('frames')

try:
    for i in range(1000):
        feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
        grad = dc_dz.eval(session=sess, feed_dict=feed_dict)
        z_guess -= grad * lr

        # decay/attenuate learning rate to 0.05 of the original over 1000 frames
        lr *= 0.997

        # Redraw any entry the step pushed out of range.  The replacements
        # come from a normal truncated to [-2, 2] and scaled by `truncation`,
        # so they are guaranteed to fall back inside +/- z_limit; an
        # unbounded np.random.randn draw could land outside it again.
        # Presumably this matches how truncated_z_sample draws -- confirm.
        indices = np.logical_or(z_guess <= -z_limit, z_guess >= +z_limit)
        n_resample = np.count_nonzero(indices)
        if n_resample:
            z_guess[indices] = truncation * truncnorm.rvs(-2, 2, size=n_resample)

        # Render the updated guess and save it as this iteration's frame.
        feed_dict = {input_z: z_guess, input_y: label, input_trunc: truncation}
        phi_guess = sess.run(output, feed_dict=feed_dict)
        guess_im = imgrid(imconvert_uint8(phi_guess), cols=3)
        imwrite('frames/{:06d}.png'.format(i), guess_im)

        # display the progress every 10 frames
        if i % 10 == 0:
            comparison = imgrid(np.asarray([impostor_im, guess_im, target_im]), cols=3, pad=10)
            # clear_output(wait=True)
            print('lr: {}, iter: {}, grad_std: {}'.format(lr, i, np.std(grad)))
            imshow(comparison, format='jpeg')
except KeyboardInterrupt:
    # Allow the search to be stopped early by hand; frames written so far
    # and the current z_guess are kept.
    pass