summaryrefslogtreecommitdiff
path: root/inversion/interpolation.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-08 21:43:30 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-08 21:43:30 +0100
commitfb70ab05768fa4a54358dc1f304b68bc7aff6dae (patch)
tree6ba4c805ce37b5b8827b08946f0b22f639fa3e14 /inversion/interpolation.py
parent326db345db13b1ab3a76406644654cb78b4d1b8d (diff)
inversion json files
Diffstat (limited to 'inversion/interpolation.py')
-rw-r--r--inversion/interpolation.py190
1 file changed, 190 insertions, 0 deletions
diff --git a/inversion/interpolation.py b/inversion/interpolation.py
new file mode 100644
index 0000000..7133337
--- /dev/null
+++ b/inversion/interpolation.py
@@ -0,0 +1,190 @@
+# ------------------------------------------------------------------------------
+# Linear interpolation between inverted images and generated images.
+# ------------------------------------------------------------------------------
+
+import functools
+import h5py
+import itertools
+import numpy as np
+import os
+import pickle
+import params
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
# --------------------------
# Hyper-parameters.
# --------------------------
if len(sys.argv) < 2:
    sys.exit('Must provide a configuration file.')
# Rebinds the `params` module name to the loaded config object; the module
# itself is not needed after this point.
params = params.Params(sys.argv[1])

# --------------------------
# Global variables.
# --------------------------
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size
SAMPLES_DIR = 'interpolation'  # output dir for interpolation grids
INVERSES_DIR = 'inverses'      # input dir holding the inversion HDF5 dataset
# exist_ok=True avoids the check-then-create race of exists()+makedirs().
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(INVERSES_DIR, exist_ok=True)
+
# --------------------------
# Util functions.
# --------------------------
def interpolate(A, B, num_interps):
    """Return `num_interps` convex blends of A and B, endpoints included.

    Output has shape (num_interps,) + A.shape; step 0 equals A, the last
    step equals B. Raises ValueError on mismatched shapes.
    """
    if A.shape != B.shape:
        raise ValueError('A and B must have the same shape to interpolate.')
    alphas = np.linspace(0, 1., num_interps)
    return np.stack([(1 - alpha) * A + alpha * B for alpha in alphas])
+
# One hot encoding for classes.
def one_hot(values):
    """Convert integer class ids into one-hot rows of width N_CLASS."""
    identity = np.eye(N_CLASS)
    return identity[values]
+
# Random sampler for classes.
def label_sampler(size=[BATCH_SIZE]):
    """Sample integer class ids uniformly from {0, ..., N_CLASS-1}.

    np.random.random_integers is deprecated and removed in modern NumPy;
    randint's exclusive upper bound of N_CLASS yields the identical range.
    """
    return np.random.randint(low=0, high=N_CLASS, size=size)
+
def label_hot_sampler(size=[BATCH_SIZE]):
    """Sample a batch of class labels already one-hot encoded."""
    sampled = label_sampler(size=size)
    return one_hot(sampled)
+
# --------------------------
# Load Graph.
# --------------------------
# Pre-trained generator loaded as a TF-Hub module (TF1 graph mode).
generator = hub.Module(str(params.generator_path))

# Prefer an explicit 'generator' signature; fall back to the module default.
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'

input_info = generator.get_input_info_dict(gen_signature)
# Conditional GANs take a class label 'y' alongside the latent 'z'.
COND_GAN = 'y' in input_info

if COND_GAN:
    # Latent and label are tf.Variables so the training loop below can
    # overwrite them per batch via .assign().
    Z_DIM = input_info['z'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    N_CLASS = input_info['y'].get_shape().as_list()[1]
    label = tf.get_variable(name='label', dtype=tf.float32,
                            shape=[BATCH_SIZE, N_CLASS])
    gen_in = dict(params.generator_fixed_inputs)
    gen_in['z'] = latent
    gen_in['y'] = label
    gen_img = generator(gen_in, signature=gen_signature)
else:
    # Unconditional module: the single 'default' input carries the latent.
    Z_DIM = input_info['default'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    if (params.generator_fixed_inputs):
        gen_in = dict(params.generator_fixed_inputs)
        gen_in['z'] = latent
        gen_img = generator(gen_in, signature=gen_signature)
    else:
        gen_img = generator(latent, signature=gen_signature)

# Convert generated image to channels_first.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Override intermediate layer.
if params.inv_layer == 'latent':
    # Inverting directly in latent space: the 'encoding' IS the latent.
    # NOTE(review): gen_encoding is never defined on this branch, yet the
    # training loop below runs sess.run(gen_encoding) -- confirm configs
    # never set inv_layer == 'latent' for this script.
    encoding = latent
    ENC_SHAPE = [Z_DIM]
else:
    # Fetch the named intermediate tensor inside the hub module's graph...
    layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
    gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
    ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
    encoding = tf.get_variable(name='encoding', dtype=tf.float32,
                               shape=[BATCH_SIZE,] + ENC_SHAPE)
    # ...and splice the assignable variable in its place: downstream ops
    # (and hence gen_img) now read `encoding`, while gen_encoding still
    # reports what the generator produces from `latent`.
    tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))

# Define image shape (channels-first after the transpose above).
IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
# --------------------------
# Noise source.
# --------------------------
def noise_sampler():
    """Draw a standard-normal latent batch shaped [BATCH_SIZE, Z_DIM]."""
    return np.random.normal(0.0, 1.0, size=(BATCH_SIZE, Z_DIM))
+
# --------------------------
# Dataset.
# --------------------------
# HDF5 file produced by the inversion step; h5py datasets are read lazily,
# so slicing in the loop below only loads one batch at a time.
in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
in_images = in_file['xtrain']      # target images that were inverted
if COND_GAN:
    in_labels = in_file['ytrain']  # integer class ids (conditional GANs only)
in_encoding = in_file['encoding']  # inverted intermediate-layer encodings
in_latent = in_file['latent']      # inverted latent vectors z*
NUM_IMGS = in_images.shape[0] # number of images.
+
# --------------------------
# Training.
# --------------------------
# Start session.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())

# NOTE(review): assumes NUM_IMGS is a multiple of BATCH_SIZE; a short final
# slice would make the .assign() shapes mismatch -- confirm upstream.
for i in range(0, NUM_IMGS, BATCH_SIZE):
    # Set label.
    if COND_GAN:
        sess.run(label.assign(one_hot(in_labels[i:i+BATCH_SIZE])))

    # Linear interpolation between G_1(z*) and G_1(z*)+delta*.
    # Row 0 is the target image; rows 1..7 blend from the inverted encoding
    # toward the encoding the generator itself produces from z*.
    sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
    out_batch = np.ndarray(shape=[8,BATCH_SIZE]+IMG_SHAPE, dtype='uint8')
    out_batch[0] = in_images[i:i+BATCH_SIZE]
    sess.run(latent.assign(in_latent[i:i+BATCH_SIZE]))
    # gen_encoding is the pre-swap tensor, so it reflects `latent` only.
    sample_enc_2 = sess.run(gen_encoding)
    sample_enc = interpolate(sample_enc_1, sample_enc_2, 7)
    for j in range(0,7):
        sess.run(encoding.assign(sample_enc[j]))
        gen_images = sess.run(gen_img)
        gen_images = vs.data2img(gen_images)
        out_batch[j+1] = gen_images

    # [rows, batch, ...] -> [batch, rows, ...] so each image's strip is contiguous.
    out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
    for k in range(BATCH_SIZE):
        out_batch_k = vs.seq_transform(out_batch[k])
        # Add white padding: a `pad`-pixel gap between the target image and
        # the 7 interpolants.
        pad = 20
        out_batch_kk = np.ndarray(shape=[IMG_SHAPE[1], IMG_SHAPE[1]*8+pad, 3],
                                  dtype='uint8')
        out_batch_kk[:,:IMG_SHAPE[1],:] = out_batch_k[:,:IMG_SHAPE[1],:]
        out_batch_kk[:,IMG_SHAPE[1]:IMG_SHAPE[1]+pad,:] = 255
        out_batch_kk[:,IMG_SHAPE[1]+pad:,:] = out_batch_k[:,IMG_SHAPE[1]:,:]

        vs.save_image('{}/interpolation_delta_{}.png'.format(SAMPLES_DIR, i+k), out_batch_kk)
        print('Saved delta interpolation for img: {}.'.format(i+k))

    # Linear interpolation between G_1(z_random) and G_1(z*)+delta*.
    # 8 random restarts x 8 interpolation steps per image -> an 8x8 grid.
    sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
    sample_z_1 = in_latent[i:i+BATCH_SIZE]
    out_batch = np.ndarray(shape=[8*8,BATCH_SIZE]+IMG_SHAPE, dtype='uint8')
    for k in range(8):
        sample_z_2 = noise_sampler()
        sess.run(latent.assign(sample_z_2))
        sample_enc_2 = sess.run(gen_encoding)
        # Latent and encoding are interpolated in lockstep between the
        # inverse and the random draw.
        sample_z = interpolate(sample_z_1, sample_z_2, 8)
        sample_enc = interpolate(sample_enc_1, sample_enc_2, 8)
        for j in range(8):
            sess.run(latent.assign(sample_z[j]))
            sess.run(encoding.assign(sample_enc[j]))
            gen_images = sess.run(gen_img)
            gen_images = vs.data2img(gen_images)
            out_batch[k*8+j] = gen_images

    out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
    for k in range(BATCH_SIZE):
        out_batch_k = vs.grid_transform(out_batch[k])
        vs.save_image('{}/interpolation_rand_{}.png'.format(SAMPLES_DIR, i+k), out_batch_k)
        print('Saved rand interpolation for img: {}.'.format(i+k))

sess.close()