Diffstat (limited to 'inversion/interpolation.py')
| -rw-r--r-- | inversion/interpolation.py | 190 |
1 file changed, 190 insertions, 0 deletions
diff --git a/inversion/interpolation.py b/inversion/interpolation.py
new file mode 100644
index 0000000..7133337
--- /dev/null
+++ b/inversion/interpolation.py
@@ -0,0 +1,190 @@
+# ------------------------------------------------------------------------------
+# Linear interpolation between inverted images and generated images.
+# ------------------------------------------------------------------------------
+
+import functools
+import h5py
+import itertools
+import numpy as np
+import os
+import pickle
+import params
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+if len(sys.argv) < 2:
+    sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Global variables.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'interpolation'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+    os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+    os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+def interpolate(A, B, num_interps):
+    alphas = np.linspace(0, 1., num_interps)
+    if A.shape != B.shape:
+        raise ValueError('A and B must have the same shape to interpolate.')
+    return np.array([(1-a)*A + a*B for a in alphas])
+
+# One hot encoding for classes.
+def one_hot(values):
+    return np.eye(N_CLASS)[values]
+
+# Random sampler for classes.
+def label_sampler(size=[BATCH_SIZE]):
+    return np.random.random_integers(low=0, high=N_CLASS-1, size=size)
+
+def label_hot_sampler(size=[BATCH_SIZE]):
+    return one_hot(label_sampler(size=size))
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+    gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+    Z_DIM = input_info['z'].get_shape().as_list()[1]
+    latent = tf.get_variable(name='latent', dtype=tf.float32,
+                             shape=[BATCH_SIZE, Z_DIM])
+    N_CLASS = input_info['y'].get_shape().as_list()[1]
+    label = tf.get_variable(name='label', dtype=tf.float32,
+                            shape=[BATCH_SIZE, N_CLASS])
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_in['y'] = label
+    gen_img = generator(gen_in, signature=gen_signature)
+else:
+    Z_DIM = input_info['default'].get_shape().as_list()[1]
+    latent = tf.get_variable(name='latent', dtype=tf.float32,
+                             shape=[BATCH_SIZE, Z_DIM])
+    if params.generator_fixed_inputs:
+        gen_in = dict(params.generator_fixed_inputs)
+        gen_in['z'] = latent
+        gen_img = generator(gen_in, signature=gen_signature)
+    else:
+        gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+    encoding = gen_encoding = latent  # Alias so sess.run(gen_encoding) below also works in this case.
+    ENC_SHAPE = [Z_DIM]
+else:
+    layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+    gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+    ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+    encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+                               shape=[BATCH_SIZE] + ENC_SHAPE)
+    tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))  # Reroute the graph to read the encoding variable.
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+    return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+# --------------------------
+# Dataset.
+# --------------------------
+in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+in_images = in_file['xtrain']
+if COND_GAN:
+    in_labels = in_file['ytrain']
+in_encoding = in_file['encoding']
+in_latent = in_file['latent']
+NUM_IMGS = in_images.shape[0]  # number of images.
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+for i in range(0, NUM_IMGS, BATCH_SIZE):
+    # Set label.
+    if COND_GAN:
+        sess.run(label.assign(one_hot(in_labels[i:i+BATCH_SIZE])))
+
+    # Linear interpolation between G_1(z*) and G_1(z*)+delta*.
+    sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
+    out_batch = np.ndarray(shape=[8, BATCH_SIZE] + IMG_SHAPE, dtype='uint8')
+    out_batch[0] = in_images[i:i+BATCH_SIZE]
+    sess.run(latent.assign(in_latent[i:i+BATCH_SIZE]))
+    sample_enc_2 = sess.run(gen_encoding)
+    sample_enc = interpolate(sample_enc_1, sample_enc_2, 7)
+    for j in range(0, 7):
+        sess.run(encoding.assign(sample_enc[j]))
+        gen_images = sess.run(gen_img)
+        gen_images = vs.data2img(gen_images)
+        out_batch[j+1] = gen_images
+
+    out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
+    for k in range(BATCH_SIZE):
+        out_batch_k = vs.seq_transform(out_batch[k])
+        # Add white padding.
+        pad = 20
+        out_batch_kk = np.ndarray(shape=[IMG_SHAPE[1], IMG_SHAPE[1]*8 + pad, 3],
+                                  dtype='uint8')
+        out_batch_kk[:, :IMG_SHAPE[1], :] = out_batch_k[:, :IMG_SHAPE[1], :]
+        out_batch_kk[:, IMG_SHAPE[1]:IMG_SHAPE[1]+pad, :] = 255
+        out_batch_kk[:, IMG_SHAPE[1]+pad:, :] = out_batch_k[:, IMG_SHAPE[1]:, :]
+
+        vs.save_image('{}/interpolation_delta_{}.png'.format(SAMPLES_DIR, i+k), out_batch_kk)
+        print('Saved delta interpolation for img: {}.'.format(i+k))
+
+    # Linear interpolation between G_1(z_random) and G_1(z*)+delta*.
+    sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
+    sample_z_1 = in_latent[i:i+BATCH_SIZE]
+    out_batch = np.ndarray(shape=[8*8, BATCH_SIZE] + IMG_SHAPE, dtype='uint8')
+    for k in range(8):
+        sample_z_2 = noise_sampler()
+        sess.run(latent.assign(sample_z_2))
+        sample_enc_2 = sess.run(gen_encoding)
+        sample_z = interpolate(sample_z_1, sample_z_2, 8)
+        sample_enc = interpolate(sample_enc_1, sample_enc_2, 8)
+        for j in range(8):
+            sess.run(latent.assign(sample_z[j]))
+            sess.run(encoding.assign(sample_enc[j]))
+            gen_images = sess.run(gen_img)
+            gen_images = vs.data2img(gen_images)
+            out_batch[k*8+j] = gen_images
+
+    out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
+    for k in range(BATCH_SIZE):
+        out_batch_k = vs.grid_transform(out_batch[k])
+        vs.save_image('{}/interpolation_rand_{}.png'.format(SAMPLES_DIR, i+k), out_batch_k)
+        print('Saved rand interpolation for img: {}.'.format(i+k))
+
+sess.close()
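For reference, a minimal standalone sketch of what the interpolate helper above returns: num_interps evenly spaced convex combinations of its inputs, stacked along a new leading axis, so the first slice equals A and the last equals B. The 2x3 arrays below are hypothetical and not taken from the script.

import numpy as np

def interpolate(A, B, num_interps):
    # Same logic as the helper in interpolation.py.
    alphas = np.linspace(0, 1., num_interps)
    if A.shape != B.shape:
        raise ValueError('A and B must have the same shape to interpolate.')
    return np.array([(1 - a) * A + a * B for a in alphas])

A = np.zeros([2, 3])  # hypothetical batch of two 3-dimensional encodings
B = np.ones([2, 3])
steps = interpolate(A, B, 5)
print(steps.shape)    # (5, 2, 3); steps[0] == A and steps[-1] == B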
