# ------------------------------------------------------------------------------
# Linear interpolation between inverted images and generated images.
# ------------------------------------------------------------------------------
import functools
import h5py
import itertools
import numpy as np
import os
import pickle
import params
import scipy
import sys
import tensorflow as tf
import tensorflow_hub as hub
import time
import visualize as vs

# --------------------------
# Hyper-parameters.
# --------------------------
if len(sys.argv) < 2:
    sys.exit('Must provide a configuration file.')
# Rebinds the `params` module name to the parsed configuration object.
params = params.Params(sys.argv[1])

# --------------------------
# Global variables.
# --------------------------
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size
SAMPLES_DIR = 'interpolation'
INVERSES_DIR = 'inverses'

if not os.path.exists(SAMPLES_DIR):
    os.makedirs(SAMPLES_DIR)
if not os.path.exists(INVERSES_DIR):
    os.makedirs(INVERSES_DIR)


# --------------------------
# Util functions.
# --------------------------
def interpolate(A, B, num_interps):
    """Return `num_interps` evenly spaced linear blends from `A` to `B`.

    Blend weights come from `np.linspace(0, 1, num_interps)`, so the first
    blend equals A and, when num_interps >= 2, the last equals B.

    Args:
      A: first endpoint (any array-like, e.g. a numpy array or h5py slice).
      B: second endpoint, same shape as A.
      num_interps: number of blends to produce.

    Returns:
      numpy array of shape `(num_interps,) + A.shape`.

    Raises:
      ValueError: if A and B do not have the same shape.
    """
    # asarray lets h5py datasets / lists be used: they have no arithmetic.
    A = np.asarray(A)
    B = np.asarray(B)
    if A.shape != B.shape:
        raise ValueError('A and B must have the same shape to interpolate.')
    alphas = np.linspace(0, 1., num_interps)
    blends = np.array([(1 - a) * A + a * B for a in alphas])
    # The reshape is a no-op for num_interps >= 1; for 0 it gives the empty
    # result the documented shape instead of a flat (0,) array.
    return blends.reshape((len(alphas),) + A.shape)


def one_hot(values):
    """One-hot encode integer class ids `values` over N_CLASS classes.

    N_CLASS is a module global that is only defined when the generator is
    conditional, so this must not be called for unconditional generators.
    """
    return np.eye(N_CLASS)[values]


def label_sampler(size=(BATCH_SIZE,)):
    """Draw uniform random class ids in [0, N_CLASS - 1] with shape `size`."""
    # np.random.random_integers is deprecated. randint's `high` is exclusive,
    # so high=N_CLASS covers the same inclusive range [0, N_CLASS - 1].
    return np.random.randint(low=0, high=N_CLASS, size=size)


def label_hot_sampler(size=(BATCH_SIZE,)):
    """Draw random class ids with shape `size` and return them one-hot encoded."""
    return one_hot(label_sampler(size=size))


# --------------------------
# Load Graph.
# --------------------------
generator = hub.Module(str(params.generator_path))

# Prefer the module's 'generator' signature; fall back to 'default'.
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'
input_info = generator.get_input_info_dict(gen_signature)

# The generator is treated as class-conditional when it has a 'y' input.
COND_GAN = 'y' in input_info

if COND_GAN:
    Z_DIM = input_info['z'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    N_CLASS = input_info['y'].get_shape().as_list()[1]
    label = tf.get_variable(name='label', dtype=tf.float32,
                            shape=[BATCH_SIZE, N_CLASS])
    # `or {}` matches the unconditional branch, which treats a falsy
    # generator_fixed_inputs as "no extra inputs" instead of crashing.
    gen_in = dict(params.generator_fixed_inputs or {})
    gen_in['z'] = latent
    gen_in['y'] = label
    gen_img = generator(gen_in, signature=gen_signature)
else:
    Z_DIM = input_info['default'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    if params.generator_fixed_inputs:
        gen_in = dict(params.generator_fixed_inputs)
        gen_in['z'] = latent
        gen_img = generator(gen_in, signature=gen_signature)
    else:
        gen_img = generator(latent, signature=gen_signature)

# Convert generated image to channels_first.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Override intermediate layer.
if params.inv_layer == 'latent':
    encoding = latent
    # The loop evaluates `gen_encoding` to get the encoding produced by the
    # current latent. In latent mode that encoding is the latent itself;
    # without this alias the loop would raise NameError.
    gen_encoding = latent
    ENC_SHAPE = [Z_DIM]
else:
    layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
    gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
    ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
    encoding = tf.get_variable(name='encoding', dtype=tf.float32,
                               shape=[BATCH_SIZE, ] + ENC_SHAPE)
    # Rewire the graph so downstream layers read the `encoding` variable
    # instead of the computed layer output; `gen_encoding` itself is still
    # computed from `latent` and can be fetched.
    tf.contrib.graph_editor.swap_ts(gen_encoding,
                                    tf.convert_to_tensor(encoding))

# Define image shape (channels_first, per the transpose above).
IMG_SHAPE = gen_img.get_shape().as_list()[1:]


# --------------------------
# Noise source.
# --------------------------
def noise_sampler():
    """Return a [BATCH_SIZE, Z_DIM] array of standard-normal latent samples."""
    return np.random.normal(size=[BATCH_SIZE, Z_DIM])


# --------------------------
# Dataset.
# --------------------------
in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
in_images = in_file['xtrain']
if COND_GAN:
    in_labels = in_file['ytrain']
in_encoding = in_file['encoding']
in_latent = in_file['latent']
NUM_IMGS = in_images.shape[0]  # number of images.

# --------------------------
# Training.
# --------------------------
# Generated frames in the delta interpolation; the saved strip has one extra
# leading column holding the input image.
NUM_DELTA_STEPS = 7
# Random endpoints (grid rows) and frames per endpoint (grid columns) in the
# random interpolation grid.
NUM_RAND_ROWS = 8
NUM_RAND_STEPS = 8
# Width in pixels of the white separator after the input image in a strip.
STRIP_PAD = 20


def _pad_batch(batch, size):
    """Return `batch` as a numpy array padded along axis 0 to `size` rows.

    The TF variables fed below have a fixed leading dimension of BATCH_SIZE,
    so a short final batch is padded by repeating its last row. Outputs for
    the padded rows are never saved.
    """
    batch = np.asarray(batch)
    missing = size - batch.shape[0]
    if missing <= 0:
        return batch
    pad_width = [(0, missing)] + [(0, 0)] * (batch.ndim - 1)
    return np.pad(batch, pad_width, mode='edge')


def _make_setter(var):
    """Return `set_value(session, value)` that writes `value` into `var`."""
    # Build the placeholder and assign op once. Calling `var.assign(value)`
    # per iteration would add new nodes, including a constant holding
    # `value`, to the graph on every call, so the graph would grow without
    # bound.
    value_ph = tf.placeholder(var.dtype.base_dtype, shape=var.get_shape())
    assign_op = var.assign(value_ph)

    def set_value(session, value):
        session.run(assign_op, feed_dict={value_ph: value})
    return set_value


def _save_delta_interpolation(sess, start, n_valid, images, z, enc):
    """Save one strip per image: input | white pad | interpolated frames.

    The frames interpolate from the stored encoding `enc` to the encoding the
    generator produces for latent `z`. Only the first `n_valid` images of
    the batch are saved, named from index `start`.
    """
    out_batch = np.ndarray(shape=[NUM_DELTA_STEPS + 1, BATCH_SIZE] + IMG_SHAPE,
                           dtype='uint8')
    out_batch[0] = images
    set_latent(sess, z)
    enc_of_z = sess.run(gen_encoding)
    sample_enc = interpolate(enc, enc_of_z, NUM_DELTA_STEPS)
    for j in range(NUM_DELTA_STEPS):
        set_encoding(sess, sample_enc[j])
        out_batch[j + 1] = vs.data2img(sess.run(gen_img))
    # [frame, image, ...] -> [image, frame, ...].
    out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])

    # IMG_SHAPE[1] is used for both height and width, i.e. this assumes
    # square images and a 3-channel output from vs.seq_transform.
    side = IMG_SHAPE[1]
    for k in range(n_valid):
        strip = vs.seq_transform(out_batch[k])
        # Add white padding between the input image and the generated frames.
        padded = np.ndarray(
            shape=[side, side * (NUM_DELTA_STEPS + 1) + STRIP_PAD, 3],
            dtype='uint8')
        padded[:, :side, :] = strip[:, :side, :]
        padded[:, side:side + STRIP_PAD, :] = 255
        padded[:, side + STRIP_PAD:, :] = strip[:, side:, :]
        vs.save_image('{}/interpolation_delta_{}.png'.format(SAMPLES_DIR, start + k), padded)
        print('Saved delta interpolation for img: {}.'.format(start + k))


def _save_rand_interpolation(sess, start, n_valid, z, enc):
    """Save one grid per image interpolating (z, enc) towards random latents.

    Each of NUM_RAND_ROWS rows draws a fresh random latent and interpolates
    both the latent and the layer encoding over NUM_RAND_STEPS frames. Only
    the first `n_valid` images of the batch are saved, named from `start`.
    """
    out_batch = np.ndarray(
        shape=[NUM_RAND_ROWS * NUM_RAND_STEPS, BATCH_SIZE] + IMG_SHAPE,
        dtype='uint8')
    for k in range(NUM_RAND_ROWS):
        z_rand = noise_sampler()
        set_latent(sess, z_rand)
        enc_rand = sess.run(gen_encoding)
        sample_z = interpolate(z, z_rand, NUM_RAND_STEPS)
        sample_enc = interpolate(enc, enc_rand, NUM_RAND_STEPS)
        for j in range(NUM_RAND_STEPS):
            set_latent(sess, sample_z[j])
            set_encoding(sess, sample_enc[j])
            out_batch[k * NUM_RAND_STEPS + j] = vs.data2img(sess.run(gen_img))
    # [frame, image, ...] -> [image, frame, ...].
    out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
    for k in range(n_valid):
        grid = vs.grid_transform(out_batch[k])
        vs.save_image('{}/interpolation_rand_{}.png'.format(SAMPLES_DIR, start + k), grid)
        print('Saved rand interpolation for img: {}.'.format(start + k))


# Assign ops are built once, before the loop (see _make_setter).
set_latent = _make_setter(latent)
set_encoding = _make_setter(encoding)
if COND_GAN:
    set_label = _make_setter(label)

# Start session.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
try:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    for i in range(0, NUM_IMGS, BATCH_SIZE):
        # The final batch may hold fewer than BATCH_SIZE images.
        n_valid = min(BATCH_SIZE, NUM_IMGS - i)
        batch_images = _pad_batch(in_images[i:i + BATCH_SIZE], BATCH_SIZE)
        batch_z = _pad_batch(in_latent[i:i + BATCH_SIZE], BATCH_SIZE)
        batch_enc = _pad_batch(in_encoding[i:i + BATCH_SIZE], BATCH_SIZE)

        # Set label.
        if COND_GAN:
            set_label(sess, _pad_batch(one_hot(in_labels[i:i + BATCH_SIZE]),
                                       BATCH_SIZE))

        # Linear interpolation between G_1(z*) and G_1(z*)+delta*.
        _save_delta_interpolation(sess, i, n_valid, batch_images, batch_z, batch_enc)

        # Linear interpolation between G_1(z_random) and G_1(z*)+delta*.
        _save_rand_interpolation(sess, i, n_valid, batch_z, batch_enc)
finally:
    sess.close()
    in_file.close()