-rwxr-xr-x  .gitignore                          4
-rw-r--r--  inversion/LICENSE                  21
-rw-r--r--  inversion/README.md                43
-rw-r--r--  inversion/image_sample.py         295
-rw-r--r--  inversion/interpolation.py        190
-rw-r--r--  inversion/inversion.py            477
-rw-r--r--  inversion/params.py                27
-rw-r--r--  inversion/params_dense-512.json    40
-rw-r--r--  inversion/params_dense.json        40
-rw-r--r--  inversion/params_latent-512.json   40
-rw-r--r--  inversion/params_latent.json       40
-rw-r--r--  inversion/random_sample-512.json   12
-rw-r--r--  inversion/random_sample.json       12
-rw-r--r--  inversion/random_sample.py        144
-rw-r--r--  inversion/segmentation.py         191
-rw-r--r--  inversion/visualize.py             88
16 files changed, 1664 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
index 0a574af..5e4b572 100755
--- a/.gitignore
+++ b/.gitignore
@@ -64,3 +64,7 @@ data_store/imagenet/imagenet_images/
*.zip
+*.hdf5
+
+events.out.*
+
diff --git a/inversion/LICENSE b/inversion/LICENSE
new file mode 100644
index 0000000..b868d12
--- /dev/null
+++ b/inversion/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 MarcosPividori
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/inversion/README.md b/inversion/README.md
new file mode 100644
index 0000000..3be7b8d
--- /dev/null
+++ b/inversion/README.md
@@ -0,0 +1,43 @@
+Exploiting GAN Internal Capacity for High-Quality Reconstruction of Natural Images
+==================================================================================
+
+Code for reproducing experiments in ["Exploiting GAN Internal Capacity for High-Quality Reconstruction of Natural Images"](https://arxiv.org/abs/1911.05630)
+
+This directory contains the source code to invert the BigGAN generator at
+128x128 resolution. Requires TensorFlow.
+
+## Generation of Random Samples:
+Generate 1000 random samples from the BigGAN generator:
+```console
+ $> python random_sample.py random_sample.json
+```
+
+## Inversion of the Generator:
+The optimization is split into two steps, as described in the paper.
+First step, inversion to the latent space:
+```console
+ $> python inversion.py params_latent.json
+```
+
+Second step, inversion to the dense layer:
+```console
+ $> python inversion.py params_dense.json
+```
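+
+After both steps, the inverted encodings can be inspected directly from the
+output file. A minimal sketch (assuming the hdf5 keys written by
+"inversion.py": "xtrain", "latent", "encoding", "ytrain" and "err"):
+```python
+import h5py
+
+# Output of the dense-layer step (see "out_dataset" in params_dense.json).
+with h5py.File('inverses/dataset.encodings.encodings.hdf5', 'r') as f:
+    images = f['xtrain'][:]       # uint8 images, channels-first.
+    latents = f['latent'][:]      # recovered latent vectors z*.
+    encodings = f['encoding'][:]  # recovered dense-layer encodings.
+    errors = f['err'][:]          # per-image reconstruction error.
+    print('Mean reconstruction error:', errors.mean())
+```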
+
+## Interpolation:
+Generate interpolations between the inverted images and generated images:
+```console
+ $> python interpolation.py params_dense.json
+```
+
+## Segmentation:
+Segment inverted images by clustering the attention map:
+```console
+ $> python segmentation.py params_dense.json
+```
+
+Note: to replicate the experiments on real images from ImageNet, an hdf5 file
+must first be created with random images from the dataset, similar to the
+procedure in "random_sample.py" (a minimal sketch is shown below). Then, the
+two steps of optimization must be executed (modify the "dataset:" parameter in
+params_latent.json to point to the custom dataset).
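+
+A sketch of such a file, assuming 128x128 RGB images stored channels-first as
+uint8 with uint32 class labels (the layout read by "inversion.py"); the sample
+list and the "load_image_uint8" helper are hypothetical placeholders:
+```python
+import h5py
+
+NUM_IMGS, IMG_SHAPE = 1000, [3, 128, 128]
+out_file = h5py.File('inverses/dataset.hdf5', 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS] + IMG_SHAPE, dtype='uint8')
+out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+for i, (path, class_id) in enumerate(imagenet_samples):  # hypothetical list.
+    out_images[i] = load_image_uint8(path)  # uint8 array of shape [3, 128, 128].
+    out_labels[i] = class_id                # ImageNet class index.
+out_file.close()
+```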
diff --git a/inversion/image_sample.py b/inversion/image_sample.py
new file mode 100644
index 0000000..83622a1
--- /dev/null
+++ b/inversion/image_sample.py
@@ -0,0 +1,295 @@
+# ------------------------------------------------------------------------------
+# Generate random samples of the generator and save the images to a hdf5 file.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_probability as tfp
+import time
+import visualize as vs
+import argparse
+from glob import glob
+
+# --------------------------
+# Command-line arguments.
+# --------------------------
+# input_dir: directory with the target images to be inverted.
+# tag: tag used to name the output dataset and output directories.
+parser = argparse.ArgumentParser(description='Invert a directory of images.')
+parser.add_argument('--input_dir', required=True, help='Input directory of images')
+parser.add_argument('--tag', default=str(int(time.time())), help='Tag this build')
+params = parser.parse_args()
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# General parameters.
+BATCH_SIZE = 20
+SAMPLE_SIZE = 20
+assert SAMPLE_SIZE <= BATCH_SIZE
+
+INVERSION_ITERATIONS = 1000
+
+# --------------------------
+# Global directories.
+# --------------------------
+DATASET_OUT = "{}_dataset.hdf5".format(params.tag)
+SAMPLES_DIR = './outputs/{}/samples'.format(params.tag)
+INVERSES_DIR = './outputs/{}/inverses'.format(params.tag)
+LOGS_DIR = './outputs/{}/logs'.format(params.tag)
+os.makedirs(SAMPLES_DIR, exist_ok=True)
+os.makedirs(INVERSES_DIR, exist_ok=True)
+os.makedirs(LOGS_DIR, exist_ok=True)
+
+# Tensorboard logging (same helper as in inversion.py).
+summary_writer = tf.summary.FileWriter(LOGS_DIR)
+def log_stats(name, val, it):
+    summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
+    summary_writer.add_summary(summary, it)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module('https://tfhub.dev/deepmind/biggan-128/2')
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+
+Z_DIM = input_info['z'].get_shape().as_list()[1]
+latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+N_CLASS = input_info['y'].get_shape().as_list()[1]
+label = tf.get_variable(name='label', dtype=tf.float32,
+ shape=[BATCH_SIZE, N_CLASS])
+gen_in = {}
+gen_in['truncation'] = 1.0
+gen_in['z'] = latent
+gen_in['y'] = label
+gen_img = generator(gen_in, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+encoding = latent
+ENC_SHAPE = [Z_DIM]
+
+# Step counter.
+inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
+
+# Define target image.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+target = tf.get_variable(name='target', dtype=tf.int32,
+ shape=[BATCH_SIZE,] + IMG_SHAPE)
+target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1. # Norm to [-1, 1].
+
+# Monitor relu's activation.
+gen_scope = 'module_apply_' + gen_signature + '/'
+activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()\
+ .get_tensor_by_name(gen_scope + params.log_activation_layer))
+
+# --------------------------
+# Reconstruction losses.
+# --------------------------
+# Mse loss for image comparison.
+pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+mse_loss = tf.reduce_mean(pix_square_diff)
+img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3])
+
+# Use custom features for image comparison.
+feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
+
+# Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
+gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
+target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
+
+# Convert images to appropriate size for feature extraction.
+height, width = hub.get_expected_image_size(feature_extractor)
+gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
+target_img_1 = tf.image.resize_images(target_img_1, [height, width])
+
+FEATURE_EXTRACTOR_OUTPUT = 'InceptionV3/Mixed_7a'  # As in the params files.
+gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True,
+    signature='image_feature_vector')[FEATURE_EXTRACTOR_OUTPUT]
+target_feat = feature_extractor(dict(images=target_img_1), as_dict=True,
+    signature='image_feature_vector')[FEATURE_EXTRACTOR_OUTPUT]
+feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+ [BATCH_SIZE, -1])
+feat_loss = tf.reduce_mean(feat_square_diff)
+img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)
+
+# --------------------------
+# Regularization losses.
+# --------------------------
+norm_dist = tfp.distributions.Normal(0.0, 1.0)
+likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
+mode_log_prob = norm_dist.log_prob(0.0)
+likeli_loss += mode_log_prob
+
+# Per image reconstruction error.
+img_rec_err = 1.0 * img_mse_err + 1.0 * img_feat_err
+
+# Batch reconstruction error.
+rec_loss = 1.0 * mse_loss + 1.0 * feat_loss
+
+# Total inversion loss.
+inv_loss = rec_loss + 0.1 * likeli_loss
+
+# --------------------------
+# Optimizer.
+# --------------------------
+lrate = tf.train.exponential_decay(0.1, inv_step,
+ INVERSION_ITERATIONS / 2, 0.1, staircase=True)
+trained_params = [latent, label]
+optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
+ global_step=inv_step)
+reinit_optimizer = tf.variables_initializer(optimizer.variables())
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+ return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+def label_init(shape=[BATCH_SIZE, N_CLASS]):
+ return np.random.uniform(low=0.00, high=0.01, size=shape)
+
+# --------------------------
+# Dataset.
+# --------------------------
+
+paths = glob(os.path.join(params.input_dir, '*.jpg')) + \
+    glob(os.path.join(params.input_dir, '*.jpeg')) + \
+    glob(os.path.join(params.input_dir, '*.png'))
+# Load images as uint8 channels-first arrays, matching the target variable.
+sample_images = [np.transpose(vs.data2img(vs.load_image(fn, 128))[0], [2, 0, 1])
+                 for fn in sorted(paths)]
+ACTUAL_NUM_IMGS = len(sample_images)  # number of images to be inverted.
+print("Number of images: {}".format(ACTUAL_NUM_IMGS))
+NUM_IMGS = ACTUAL_NUM_IMGS
+
+# Pad the image list with copies of the last image to fill the last batch.
+while NUM_IMGS % BATCH_SIZE != 0:
+    sample_images.append(sample_images[-1])
+    NUM_IMGS += 1
+sample_images = np.array(sample_images)
+
+def sample_images_gen():
+ for i in range(int(NUM_IMGS / BATCH_SIZE)):
+ i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+ yield sample_images[i_1:i_2]
+image_gen = sample_images_gen()
+
+assert(NUM_IMGS % BATCH_SIZE == 0)
+
+# --------------------------
+# Generation.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Output file.
+out_file = h5py.File(os.path.join(INVERSES_DIR, DATASET_OUT), 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='uint8')
+out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+
+# Gradient descent w.r.t. generator's inputs.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch in image_gen:
+ # Set target.
+ sess.run(target.assign(image_batch))
+
+ # Start with a small random label vector.
+ sess.run(label.assign(label_init()))
+
+ # Start with a random vector
+ sess.run(latent.assign(noise_sampler()))
+
+ # Init optimizer.
+ sess.run(inv_step.assign(0))
+ sess.run(reinit_optimizer)
+
+ # Main optimization loop.
+ for _ in range(INVERSION_ITERATIONS):
+ _inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss,\
+ _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss,
+ rec_loss, likeli_loss, lrate, inv_train_op])
+
+ # Every 100 iterations save logs with training information.
+ if it % 100 == 0:
+ # Log losses.
+ etime = time.time() - start_time
+
+ _act_rate = sess.run(activation_rate)
+ print('activation_rate={:.4f}'.format(_act_rate))
+ log_stats('activation rate', _act_rate, it)
+
+ sys.stdout.flush()
+
+ # Log tensorboard's statistics.
+ log_stats('total loss', _inv_loss, it)
+ log_stats('mse loss', _mse_loss, it)
+ log_stats('feat loss', _feat_loss, it)
+ log_stats('rec loss', _rec_loss, it)
+ log_stats('likeli loss', _likeli_loss, it)
+ log_stats('out pos', out_pos, it)
+ log_stats('lrate', _lrate, it)
+ summary_writer.flush()
+
+ gen_images = sess.run(gen_img)
+ inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+ vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+ inv_batch = vs.grid_transform(inv_batch)
+ vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+
+ it += 1
+
+ # Save the final reconstructions and the optimized labels.
+ gen_images = sess.run(gen_img)
+ gen_images = vs.data2img(gen_images)
+ label_batch = sess.run(label)
+
+ out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
+ out_labels[out_pos:out_pos+BATCH_SIZE] = np.argmax(label_batch, axis=1)
+
+ out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
+ vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, out_pos), out_batch)
+ print('Saved samples for imgs: {}-{}.'.format(out_pos, out_pos+BATCH_SIZE))
+ out_pos += BATCH_SIZE
diff --git a/inversion/interpolation.py b/inversion/interpolation.py
new file mode 100644
index 0000000..7133337
--- /dev/null
+++ b/inversion/interpolation.py
@@ -0,0 +1,190 @@
+# ------------------------------------------------------------------------------
+# Linear interpolation between inverted images and generated images.
+# ------------------------------------------------------------------------------
+
+import functools
+import h5py
+import itertools
+import numpy as np
+import os
+import pickle
+import params
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+if len(sys.argv) < 2:
+ sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Global variables.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'interpolation'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+ os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+ os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+def interpolate(A, B, num_interps):
+ alphas = np.linspace(0, 1., num_interps)
+ if A.shape != B.shape:
+ raise ValueError('A and B must have the same shape to interpolate.')
+ return np.array([(1-a)*A + a*B for a in alphas])
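+# e.g. interpolate(A, B, 3) -> [A, (A + B) / 2, B].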
+
+# One hot encoding for classes.
+def one_hot(values):
+ return np.eye(N_CLASS)[values]
+
+# Random sampler for classes.
+def label_sampler(size=[BATCH_SIZE]):
+ return np.random.random_integers(low=0, high=N_CLASS-1, size=size)
+
+def label_hot_sampler(size=[BATCH_SIZE]):
+ return one_hot(label_sampler(size=size))
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+ Z_DIM = input_info['z'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ N_CLASS = input_info['y'].get_shape().as_list()[1]
+ label = tf.get_variable(name='label', dtype=tf.float32,
+ shape=[BATCH_SIZE, N_CLASS])
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_in['y'] = label
+ gen_img = generator(gen_in, signature=gen_signature)
+else:
+ Z_DIM = input_info['default'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ if (params.generator_fixed_inputs):
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_img = generator(gen_in, signature=gen_signature)
+ else:
+ gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+ encoding = latent
+ ENC_SHAPE = [Z_DIM]
+else:
+ layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+ gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+ ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+ encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+ shape=[BATCH_SIZE,] + ENC_SHAPE)
+ tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+ return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+# --------------------------
+# Dataset.
+# --------------------------
+in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+in_images = in_file['xtrain']
+if COND_GAN:
+ in_labels = in_file['ytrain']
+in_encoding = in_file['encoding']
+in_latent = in_file['latent']
+NUM_IMGS = in_images.shape[0] # number of images.
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+for i in range(0, NUM_IMGS, BATCH_SIZE):
+ # Set label.
+ if COND_GAN:
+ sess.run(label.assign(one_hot(in_labels[i:i+BATCH_SIZE])))
+
+ # Linear interpolation between G_1(z*) and G_1(z*)+delta*.
+ sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
+ out_batch = np.ndarray(shape=[8,BATCH_SIZE]+IMG_SHAPE, dtype='uint8')
+ out_batch[0] = in_images[i:i+BATCH_SIZE]
+ sess.run(latent.assign(in_latent[i:i+BATCH_SIZE]))
+ sample_enc_2 = sess.run(gen_encoding)
+ sample_enc = interpolate(sample_enc_1, sample_enc_2, 7)
+ for j in range(0,7):
+ sess.run(encoding.assign(sample_enc[j]))
+ gen_images = sess.run(gen_img)
+ gen_images = vs.data2img(gen_images)
+ out_batch[j+1] = gen_images
+
+ out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
+ for k in range(BATCH_SIZE):
+ out_batch_k = vs.seq_transform(out_batch[k])
+ # Add white padding.
+ pad = 20
+ out_batch_kk = np.ndarray(shape=[IMG_SHAPE[1], IMG_SHAPE[1]*8+pad, 3],
+ dtype='uint8')
+ out_batch_kk[:,:IMG_SHAPE[1],:] = out_batch_k[:,:IMG_SHAPE[1],:]
+ out_batch_kk[:,IMG_SHAPE[1]:IMG_SHAPE[1]+pad,:] = 255
+ out_batch_kk[:,IMG_SHAPE[1]+pad:,:] = out_batch_k[:,IMG_SHAPE[1]:,:]
+
+ vs.save_image('{}/interpolation_delta_{}.png'.format(SAMPLES_DIR, i+k), out_batch_kk)
+ print('Saved delta interpolation for img: {}.'.format(i+k))
+
+ # Linear interpolation between G_1(z_random) and G_1(z*)+delta*.
+ sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
+ sample_z_1 = in_latent[i:i+BATCH_SIZE]
+ out_batch = np.ndarray(shape=[8*8,BATCH_SIZE]+IMG_SHAPE, dtype='uint8')
+ for k in range(8):
+ sample_z_2 = noise_sampler()
+ sess.run(latent.assign(sample_z_2))
+ sample_enc_2 = sess.run(gen_encoding)
+ sample_z = interpolate(sample_z_1, sample_z_2, 8)
+ sample_enc = interpolate(sample_enc_1, sample_enc_2, 8)
+ for j in range(8):
+ sess.run(latent.assign(sample_z[j]))
+ sess.run(encoding.assign(sample_enc[j]))
+ gen_images = sess.run(gen_img)
+ gen_images = vs.data2img(gen_images)
+ out_batch[k*8+j] = gen_images
+
+ out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
+ for k in range(BATCH_SIZE):
+ out_batch_k = vs.grid_transform(out_batch[k])
+ vs.save_image('{}/interpolation_rand_{}.png'.format(SAMPLES_DIR, i+k), out_batch_k)
+ print('Saved rand interpolation for img: {}.'.format(i+k))
+
+sess.close()
diff --git a/inversion/inversion.py b/inversion/inversion.py
new file mode 100644
index 0000000..04c033b
--- /dev/null
+++ b/inversion/inversion.py
@@ -0,0 +1,477 @@
+# ------------------------------------------------------------------------------
+# Implementation of the inverse of Generator by Gradient descent w.r.t.
+# generator's inputs, for many intermediate layers.
+# ------------------------------------------------------------------------------
+
+import glob
+import h5py
+import itertools
+import numpy as np
+import os
+import params
+import PIL
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# Expected parameters:
+# generator_path: path to generator module.
+# generator_fixed_inputs: dictionary of fixed generator's input parameters.
+# dataset: name of the dataset (hdf5 file).
+# dataset_out: name for the output inverted dataset (hdf5 file).
+# General parameters:
+# batch_size: number of images inverted at the same time.
+# inv_it: number of iterations to invert an image.
+# inv_layer: 'latent' or name of the tensor of the custom layer to be inverted.
+# lr: learning rate.
+# decay_lr: exponential decay on the learning rate.
+# decay_n: number of exponential decays on the learning rate.
+# custom_grad_relu: replace relus with custom gradient.
+# Logging:
+# sample_size: number of images included in sampled images.
+# save_progress: whether to save intermediate images during optimization.
+# log_z_norm: log the norm of different sections of z.
+# log_activation_layer: log the percentage of active neurons in this layer.
+# Losses:
+# mse: use the mean squared error on pixels for image comparison.
+# features: use features extracted by a feature extractor for image comparison.
+# feature_extractor_path: path to feature extractor module.
+# feature_extractor_output: output name from feature extractor.
+# likeli_loss: regularization loss on the log likelihood of encodings.
+# norm_loss: regularization loss on the norm of encodings.
+# dist_loss: whether to include a loss on the dist between g1(z) and enc.
+# lambda_mse: coefficient for mse loss.
+# lambda_feat: coefficient for features loss.
+# lambda_reg: coefficient for regularization loss on latent.
+# lambda_dist: coefficient for l1 regularization on delta.
+# Latent:
+# clipping: whether to clip encoding values after every update.
+# stochastic_clipping: whether to consider stochastic clipping.
+# clip: clipping bound.
+# pre_trained_latent: load a pre-trained fixed latent vector.
+# fixed_z: do not train the latent vector.
+# Initialization:
+# init_gen_dist: initialize encodings from the generated distribution.
+# init_lo: init min value.
+# init_hi: init max value.
+if len(sys.argv) < 2:
+ sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Global directories.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+LOGS_DIR = 'logs'
+SAMPLES_DIR = 'samples'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(LOGS_DIR):
+ os.makedirs(LOGS_DIR)
+if not os.path.exists(SAMPLES_DIR):
+ os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+ os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One hot encoding for classes.
+def one_hot(values):
+ return np.eye(N_CLASS)[values]
+
+# --------------------------
+# Logging.
+# --------------------------
+summary_writer = tf.summary.FileWriter(LOGS_DIR)
+def log_stats(name, val, it):
+ summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
+ summary_writer.add_summary(summary, it)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+ Z_DIM = input_info['z'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ N_CLASS = input_info['y'].get_shape().as_list()[1]
+ label = tf.get_variable(name='label', dtype=tf.float32,
+ shape=[BATCH_SIZE, N_CLASS])
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_in['y'] = label
+ gen_img = generator(gen_in, signature=gen_signature)
+else:
+ Z_DIM = input_info['default'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ if (params.generator_fixed_inputs):
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_img = generator(gen_in, signature=gen_signature)
+ else:
+ gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+ encoding = latent
+ ENC_SHAPE = [Z_DIM]
+else:
+ layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+ gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+ ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+ encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+ shape=[BATCH_SIZE,] + ENC_SHAPE)
+ tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Step counter.
+inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
+
+# Define target image.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+target = tf.get_variable(name='target', dtype=tf.int32,
+ shape=[BATCH_SIZE,] + IMG_SHAPE)
+target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1. # Norm to [-1, 1].
+
+# Custom Gradient for Relus.
+if params.custom_grad_relu:
+ grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5,
+ 0.1, staircase=False)
+  @tf.custom_gradient
+  def relu_custom_grad(x):
+    def grad(dy):
+      return tf.where(x >= 0, dy,
+          grad_lambda * tf.where(dy < 0, dy, tf.zeros_like(dy)))
+    return tf.nn.relu(x), grad
+
+ gen_scope = 'module_apply_' + gen_signature + '/'
+ for op in tf.get_default_graph().get_operations():
+ if 'Relu' in op.name and gen_scope in op.name:
+ assert len(op.inputs) == 1
+ assert len(op.outputs) == 1
+ new_out = relu_custom_grad(op.inputs[0])
+ tf.contrib.graph_editor.swap_ts(op.outputs[0], new_out)
+
+# Operations to clip the values of the encodings.
+if params.clipping or params.stochastic_clipping:
+ assert params.clip >= 0
+ if params.stochastic_clipping:
+ new_enc = tf.where(tf.abs(latent) >= params.clip,
+ tf.random.uniform([BATCH_SIZE, Z_DIM], minval=-params.clip,
+ maxval=params.clip), latent)
+ else:
+ new_enc = tf.clip_by_value(latent, -params.clip, params.clip)
+ clip_latent = tf.assign(latent, new_enc)
+
+# Monitor relu's activation.
+if params.log_activation_layer:
+ gen_scope = 'module_apply_' + gen_signature + '/'
+ activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()\
+ .get_tensor_by_name(gen_scope + params.log_activation_layer))
+
+# --------------------------
+# Reconstruction losses.
+# --------------------------
+# Mse loss for image comparison.
+if params.mse:
+ pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+ mse_loss = tf.reduce_mean(pix_square_diff)
+ img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3])
+else:
+ mse_loss = tf.constant(0.0)
+ img_mse_err = tf.constant(0.0)
+
+# Use custom features for image comparison.
+if params.features:
+ feature_extractor = hub.Module(str(params.feature_extractor_path))
+
+ # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
+ gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
+ target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
+
+ # Convert images to appropriate size for feature extraction.
+ height, width = hub.get_expected_image_size(feature_extractor)
+ gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
+ target_img_1 = tf.image.resize_images(target_img_1, [height, width])
+
+ gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True,
+ signature='image_feature_vector')[params.feature_extractor_output]
+ target_feat = feature_extractor(dict(images=target_img_1), as_dict=True,
+ signature='image_feature_vector')[params.feature_extractor_output]
+ feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+ [BATCH_SIZE, -1])
+ feat_loss = tf.reduce_mean(feat_square_diff)
+ img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)
+else:
+ feat_loss = tf.constant(0.0)
+ img_feat_err = tf.constant(0.0)
+
+# --------------------------
+# Regularization losses.
+# --------------------------
+# Loss on the norm of the encoding.
+if params.norm_loss:
+ dim = 20
+ chi2_dist = tfp.distributions.Chi2(dim)
+ mode = dim - 2
+ mode_log_prob = chi2_dist.log_prob(mode)
+ norm_loss = 0.0
+ for i in range(int(Z_DIM / dim)):
+ squared_l2 = tf.reduce_sum(tf.square(latent[:,i*dim:(i+1)*dim]), axis=1)
+ over_mode = tf.nn.relu(squared_l2 - mode)
+ norm_loss -= tf.reduce_mean(chi2_dist.log_prob(mode + over_mode))
+ norm_loss += mode_log_prob
+else:
+ norm_loss = tf.constant(0.0)
+
+# Loss on the likelihood of the encoding.
+if params.likeli_loss:
+ norm_dist = tfp.distributions.Normal(0.0, 1.0)
+ likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
+ mode_log_prob = norm_dist.log_prob(0.0)
+ likeli_loss += mode_log_prob
+else:
+ likeli_loss = tf.constant(0.0)
+
+# Regularization loss.
+reg_loss = norm_loss + likeli_loss
+
+# Loss on the l1 distance between gen_encoding and inverted encoding.
+if params.dist_loss:
+ dist_loss = tf.reduce_mean(tf.abs(encoding - gen_encoding))
+else:
+ dist_loss = tf.constant(0.0)
+
+# Per image reconstruction error.
+img_rec_err = params.lambda_mse * img_mse_err\
+ + params.lambda_feat * img_feat_err
+
+# Batch reconstruction error.
+rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss
+
+# Total inversion loss.
+inv_loss = rec_loss + params.lambda_reg * reg_loss\
+ + params.lambda_dist * dist_loss
+
+# --------------------------
+# Optimizer.
+# --------------------------
+if params.decay_lr:
+ lrate = tf.train.exponential_decay(params.lr, inv_step,
+ params.inv_it / params.decay_n, 0.1, staircase=True)
+else:
+ lrate = tf.constant(params.lr)
+trained_params = [encoding] if params.fixed_z else [latent, encoding]
+optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
+ global_step=inv_step)
+reinit_optimizer = tf.variables_initializer(optimizer.variables())
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+ return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+def small_init(shape=[BATCH_SIZE, Z_DIM]):
+ return np.random.uniform(low=params.init_lo, high=params.init_hi, size=shape)
+
+# --------------------------
+# Dataset.
+# --------------------------
+if params.dataset.endswith('.hdf5'):
+ in_file = h5py.File(params.dataset, 'r')
+ sample_images = in_file['xtrain']
+ if COND_GAN:
+ sample_labels = in_file['ytrain']
+ NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
+ print("Number of images: {}".format(NUM_IMGS))
+ def sample_images_gen():
+ for i in range(int(NUM_IMGS / BATCH_SIZE)):
+ i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+ if COND_GAN:
+ yield sample_images[i_1:i_2], sample_labels[i_1:i_2]
+ else:
+ yield sample_images[i_1:i_2], np.zeros(BATCH_SIZE)
+ image_gen = sample_images_gen()
+ if 'latent' in in_file:
+ sample_latents = in_file['latent']
+ def sample_latent_gen():
+ for i in range(int(NUM_IMGS / BATCH_SIZE)):
+ i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+ yield sample_latents[i_1:i_2]
+ latent_gen = sample_latent_gen()
+ assert(NUM_IMGS % BATCH_SIZE == 0)
+else:
+ sys.exit('Unknown dataset {}.'.format(params.dataset))
+
+NUM_IMGS -= NUM_IMGS % BATCH_SIZE
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Output file.
+out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
+ dtype='uint8')
+out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE)
+out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM])
+if COND_GAN:
+ out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+out_err = out_file.create_dataset('err', (NUM_IMGS,))
+
+# Gradient descent w.r.t. generator's inputs.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch in image_gen:
+
+ # Save target.
+ sess.run(target.assign(image_batch))
+ if COND_GAN:
+ sess.run(label.assign(one_hot(label_batch)))
+
+ # Initialize encodings to random values.
+ if params.pre_trained_latent:
+ sess.run(latent.assign(next(latent_gen)))
+ if params.inv_layer != 'latent':
+ sess.run(encoding.assign(gen_encoding))
+ else:
+ if params.init_gen_dist:
+ sess.run(latent.assign(noise_sampler()))
+ if params.inv_layer != 'latent':
+ sess.run(encoding.assign(gen_encoding))
+ else:
+ sess.run(latent.assign(small_init()))
+ if params.inv_layer != 'latent':
+ sess.run(encoding.assign(small_init(shape=[BATCH_SIZE,] + ENC_SHAPE)))
+
+ # Init optimizer.
+ sess.run(inv_step.assign(0))
+ sess.run(reinit_optimizer)
+
+ # Main optimization loop.
+ print("Total iterations: {}".format(params.inv_it))
+ for _ in range(params.inv_it):
+
+ _inv_loss, _mse_loss, _feat_loss, _rec_loss, _reg_loss, _dist_loss,\
+ _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss,
+ rec_loss, reg_loss, dist_loss, lrate, inv_train_op])
+
+ if params.clipping or params.stochastic_clipping:
+ sess.run(clip_latent)
+
+ # Every 100 iterations save logs with training information.
+ if it % 100 == 99:
+ # Log losses.
+ etime = time.time() - start_time
+ print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
+ 'feat [{:.4f}] rec [{:.4f}] reg [{:.4f}] dist [{:.4f}] '
+ 'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
+ _feat_loss, _rec_loss, _reg_loss, _dist_loss, _lrate))
+
+ if params.log_z_norm:
+ _lat = sess.run(latent)
+ dim = 20 if Z_DIM == 120 else Z_DIM
+ for i in range(int(Z_DIM/dim)):
+ _subset = _lat[:,i*dim:(i+1)*dim]
+ print('section {:1d}: norm={:.4f} (exp={:.4f}) min={:.4f} max={:.4f}'\
+ .format(i, np.mean(np.linalg.norm(_subset, axis=1)),
+ np.sqrt(dim-2), np.min(_subset), np.max(_subset)))
+
+ if params.log_activation_layer:
+ _act_rate = sess.run(activation_rate)
+ print('activation_rate={:.4f}'.format(_act_rate))
+ log_stats('activation rate', _act_rate, it)
+
+ sys.stdout.flush()
+
+ # Log tensorboard's statistics.
+ log_stats('total loss', _inv_loss, it)
+ log_stats('mse loss', _mse_loss, it)
+ log_stats('feat loss', _feat_loss, it)
+ log_stats('rec loss', _rec_loss, it)
+ log_stats('reg loss', _reg_loss, it)
+ log_stats('dist loss', _dist_loss, it)
+ log_stats('out pos', out_pos, it)
+ log_stats('lrate', _lrate, it)
+ summary_writer.flush()
+
+ # Save target images and reconstructions.
+ if params.save_progress:
+ assert SAMPLE_SIZE <= BATCH_SIZE
+ gen_images = sess.run(gen_img)
+ inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+ vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+ inv_batch = vs.grid_transform(inv_batch)
+ vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+
+ # Save linear interpolation between the actual and generated encodings.
+ if params.dist_loss and it % 1000 == 999:
+ enc_batch, gen_enc = sess.run([encoding, gen_encoding])
+ for j in range(10):
+ custom_enc = gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0)
+ sess.run(encoding.assign(custom_enc))
+ gen_images = sess.run(gen_img)
+ inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+ vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+ inv_batch = vs.grid_transform(inv_batch)
+ vs.save_image('{}/progress_{}_lat_{}.png'.format(SAMPLES_DIR,it,j),
+ inv_batch)
+ sess.run(encoding.assign(enc_batch))
+
+ # It counter.
+ it += 1
+
+ # Save samples of inverted images.
+ if SAMPLE_SIZE > 0:
+ assert SAMPLE_SIZE <= BATCH_SIZE
+ gen_images = sess.run(gen_img)
+ inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+ vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+ inv_batch = vs.grid_transform(inv_batch)
+ vs.save_image('{}/{}.png'.format(SAMPLES_DIR, out_pos), inv_batch)
+ print('Saved samples for out_pos: {}.'.format(out_pos))
+
+ # Save images that are ready.
+ latent_batch, enc_batch, rec_err_batch =\
+ sess.run([latent, encoding, img_rec_err])
+ out_lat[out_pos:out_pos+BATCH_SIZE] = latent_batch
+ out_enc[out_pos:out_pos+BATCH_SIZE] = enc_batch
+ out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
+ if COND_GAN:
+ out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
+ out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
+ out_pos += BATCH_SIZE
+
+print('Mean reconstruction error: {}'.format(np.mean(out_err)))
+print('Stdev reconstruction error: {}'.format(np.std(out_err)))
+print('End of inversion.')
+out_file.close()
+sess.close()
diff --git a/inversion/params.py b/inversion/params.py
new file mode 100644
index 0000000..dd9a358
--- /dev/null
+++ b/inversion/params.py
@@ -0,0 +1,27 @@
+# ------------------------------------------------------------------------------
+# Util class for hyperparams.
+# ------------------------------------------------------------------------------
+
+import json
+
+class Params():
+ """Class that loads hyperparameters from a json file."""
+
+ def __init__(self, json_path):
+ self.update(json_path)
+
+ def save(self, json_path):
+ """Saves parameters to json file."""
+ with open(json_path, 'w') as f:
+ json.dump(self.__dict__, f, indent=4)
+
+ def update(self, json_path):
+ """Loads parameters from json file."""
+ with open(json_path) as f:
+ params = json.load(f)
+ self.__dict__.update(params)
+
+ @property
+ def dict(self):
+ """Gives dict-like access to Params instance."""
+ return self.__dict__
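+
+# Example usage (a minimal sketch; "lr" is a key present in the params files):
+#   params = Params('params_latent.json')
+#   print(params.lr)          # attribute-style access
+#   print(params.dict['lr'])  # dict-style access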
diff --git a/inversion/params_dense-512.json b/inversion/params_dense-512.json
new file mode 100644
index 0000000..cbe7a3e
--- /dev/null
+++ b/inversion/params_dense-512.json
@@ -0,0 +1,40 @@
+{
+ "decay_n": 2,
+ "features": true,
+ "clip": 1.0,
+ "stochastic_clipping": false,
+ "clipping": false,
+ "dataset": "inverses/dataset.encodings.hdf5",
+ "inv_layer": "Generator_2/G_Z/Reshape:0",
+ "decay_lr": true,
+ "inv_it": 15000,
+ "generator_path": "https://tfhub.dev/deepmind/biggan-512/2",
+ "attention_map_layer": "Generator_2/attention/Softmax:0",
+ "pre_trained_latent": true,
+ "lambda_dist": 10.0,
+ "likeli_loss": false,
+ "init_hi": 0.001,
+ "lr": 0.01,
+ "norm_loss": false,
+ "generator_fixed_inputs": {
+ "truncation": 1.0
+ },
+ "log_z_norm": false,
+ "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",
+ "mse": true,
+ "custom_grad_relu": false,
+ "random_label": false,
+ "lambda_feat": 1.0,
+ "init_gen_dist": false,
+ "log_activation_layer": "Generator_2/GBlock/Relu:0",
+ "batch_size": 4,
+ "fixed_z": true,
+ "feature_extractor_output": "InceptionV3/Mixed_7a",
+ "init_lo": -0.001,
+ "lambda_mse": 1.0,
+ "lambda_reg": 0.1,
+ "dist_loss": true,
+ "sample_size": 4,
+ "out_dataset": "dataset.encodings.encodings.hdf5",
+ "save_progress": true
+}
diff --git a/inversion/params_dense.json b/inversion/params_dense.json
new file mode 100644
index 0000000..73bfbe8
--- /dev/null
+++ b/inversion/params_dense.json
@@ -0,0 +1,40 @@
+{
+ "decay_n": 2,
+ "features": true,
+ "clip": 1.0,
+ "stochastic_clipping": false,
+ "clipping": false,
+ "dataset": "inverses/dataset.encodings.hdf5",
+ "inv_layer": "Generator_2/G_Z/Reshape:0",
+ "decay_lr": true,
+ "inv_it": 15000,
+ "generator_path": "https://tfhub.dev/deepmind/biggan-128/2",
+ "attention_map_layer": "Generator_2/attention/Softmax:0",
+ "pre_trained_latent": true,
+ "lambda_dist": 10.0,
+ "likeli_loss": false,
+ "init_hi": 0.001,
+ "lr": 0.01,
+ "norm_loss": false,
+ "generator_fixed_inputs": {
+ "truncation": 1.0
+ },
+ "log_z_norm": false,
+ "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",
+ "mse": true,
+ "custom_grad_relu": false,
+ "random_label": false,
+ "lambda_feat": 1.0,
+ "init_gen_dist": false,
+ "log_activation_layer": "Generator_2/GBlock/Relu:0",
+ "batch_size": 25,
+ "fixed_z": true,
+ "feature_extractor_output": "InceptionV3/Mixed_7a",
+ "init_lo": -0.001,
+ "lambda_mse": 1.0,
+ "lambda_reg": 0.1,
+ "dist_loss": true,
+ "sample_size": 25,
+ "out_dataset": "dataset.encodings.encodings.hdf5",
+ "save_progress": true
+}
diff --git a/inversion/params_latent-512.json b/inversion/params_latent-512.json
new file mode 100644
index 0000000..5fa0c19
--- /dev/null
+++ b/inversion/params_latent-512.json
@@ -0,0 +1,40 @@
+{
+ "decay_n": 2,
+ "features": true,
+ "clip": 1.0,
+ "stochastic_clipping": false,
+ "clipping": false,
+ "dataset": "inverses/dataset.hdf5",
+ "inv_layer": "latent",
+ "decay_lr": true,
+ "inv_it": 15000,
+ "generator_path": "https://tfhub.dev/deepmind/biggan-512/2",
+ "attention_map_layer": "Generator_2/attention/Softmax:0",
+ "pre_trained_latent": false,
+ "lambda_dist": 0.0,
+ "likeli_loss": true,
+ "init_hi": 0.001,
+ "lr": 0.1,
+ "norm_loss": false,
+ "generator_fixed_inputs": {
+ "truncation": 1.0
+ },
+ "log_z_norm": true,
+ "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",
+ "mse": true,
+ "custom_grad_relu": false,
+ "random_label": false,
+ "lambda_feat": 1.0,
+ "init_gen_dist": false,
+ "log_activation_layer": "Generator_2/GBlock/Relu:0",
+ "batch_size": 4,
+ "fixed_z": false,
+ "feature_extractor_output": "InceptionV3/Mixed_7a",
+ "init_lo": -0.001,
+ "lambda_mse": 1.0,
+ "lambda_reg": 0.1,
+ "dist_loss": false,
+ "sample_size": 4,
+ "out_dataset": "dataset.encodings.hdf5",
+ "save_progress": true
+}
diff --git a/inversion/params_latent.json b/inversion/params_latent.json
new file mode 100644
index 0000000..4642b52
--- /dev/null
+++ b/inversion/params_latent.json
@@ -0,0 +1,40 @@
+{
+ "decay_n": 2,
+ "features": true,
+ "clip": 1.0,
+ "stochastic_clipping": false,
+ "clipping": false,
+ "dataset": "inverses/dataset.hdf5",
+ "inv_layer": "latent",
+ "decay_lr": true,
+ "inv_it": 15000,
+ "generator_path": "https://tfhub.dev/deepmind/biggan-128/2",
+ "attention_map_layer": "Generator_2/attention/Softmax:0",
+ "pre_trained_latent": false,
+ "lambda_dist": 0.0,
+ "likeli_loss": true,
+ "init_hi": 0.001,
+ "lr": 0.1,
+ "norm_loss": false,
+ "generator_fixed_inputs": {
+ "truncation": 1.0
+ },
+ "log_z_norm": true,
+ "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",
+ "mse": true,
+ "custom_grad_relu": false,
+ "random_label": false,
+ "lambda_feat": 1.0,
+ "init_gen_dist": false,
+ "log_activation_layer": "Generator_2/GBlock/Relu:0",
+ "batch_size": 25,
+ "fixed_z": false,
+ "feature_extractor_output": "InceptionV3/Mixed_7a",
+ "init_lo": -0.001,
+ "lambda_mse": 1.0,
+ "lambda_reg": 0.1,
+ "dist_loss": false,
+ "sample_size": 25,
+ "out_dataset": "dataset.encodings.hdf5",
+ "save_progress": true
+}
diff --git a/inversion/random_sample-512.json b/inversion/random_sample-512.json
new file mode 100644
index 0000000..4c5dbd0
--- /dev/null
+++ b/inversion/random_sample-512.json
@@ -0,0 +1,12 @@
+{
+ "generator_path": "https://tfhub.dev/deepmind/biggan-512/2",
+ "batch_size": 16,
+ "sample_size": 16,
+ "custom_label": 0,
+ "dataset_out": "dataset.hdf5",
+ "generator_fixed_inputs": {
+ "truncation": 1.0
+ },
+ "num_imgs": 16,
+ "random_label": true
+}
diff --git a/inversion/random_sample.json b/inversion/random_sample.json
new file mode 100644
index 0000000..c93aa1f
--- /dev/null
+++ b/inversion/random_sample.json
@@ -0,0 +1,12 @@
+{
+ "generator_path": "https://tfhub.dev/deepmind/biggan-128/2",
+ "batch_size": 20,
+ "sample_size": 20,
+ "custom_label": 0,
+ "dataset_out": "dataset.hdf5",
+ "generator_fixed_inputs": {
+ "truncation": 1.0
+ },
+ "num_imgs": 1000,
+ "random_label": true
+}
diff --git a/inversion/random_sample.py b/inversion/random_sample.py
new file mode 100644
index 0000000..61cac9c
--- /dev/null
+++ b/inversion/random_sample.py
@@ -0,0 +1,144 @@
+# ------------------------------------------------------------------------------
+# Generate random samples of the generator and save the images to a hdf5 file.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# Expected parameters:
+# generator_path: path to generator module.
+# generator_fixed_inputs: dictionary of fixed generator's input parameters.
+# dataset_out: name for the output created dataset (hdf5 file).
+# General parameters:
+# batch_size: number of images generated at the same time.
+# random_label: choose random labels.
+# num_imgs: number of instances to generate.
+# custom_label: custom label to be fixed.
+# Logging:
+# sample_size: number of images included in sampled images.
+if len(sys.argv) < 2:
+ sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# General parameters.
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+assert SAMPLE_SIZE <= BATCH_SIZE
+NUM_IMGS = params.num_imgs
+
+# --------------------------
+# Global directories.
+# --------------------------
+SAMPLES_DIR = 'random_samples'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+ os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+ os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+def one_hot(values):
+ return np.eye(N_CLASS)[values]
+
+def label_sampler(size=1):
+ return np.random.random_integers(low=0, high=N_CLASS-1, size=size)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+ Z_DIM = input_info['z'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ N_CLASS = input_info['y'].get_shape().as_list()[1]
+ label = tf.get_variable(name='label', dtype=tf.float32,
+ shape=[BATCH_SIZE, N_CLASS])
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_in['y'] = label
+ gen_img = generator(gen_in, signature=gen_signature)
+else:
+ Z_DIM = input_info['default'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ if (params.generator_fixed_inputs):
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_img = generator(gen_in, signature=gen_signature)
+ else:
+ gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+ return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+# --------------------------
+# Generation.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Output file.
+out_file = h5py.File(os.path.join(INVERSES_DIR, params.dataset_out), 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
+ dtype='uint8')
+if COND_GAN:
+ out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+
+for i in range(0, NUM_IMGS, BATCH_SIZE):
+ n_encs = min(BATCH_SIZE, NUM_IMGS - i)
+
+ if COND_GAN:
+ if params.random_label:
+ label_batch = label_sampler(BATCH_SIZE)
+ else:
+ label_batch = [params.custom_label]*BATCH_SIZE
+ sess.run(label.assign(one_hot(label_batch)))
+
+ sess.run(latent.assign(noise_sampler()))
+
+ gen_images = sess.run(gen_img)
+
+ gen_images = vs.data2img(gen_images)
+
+ out_images[i:i+n_encs] = gen_images[:n_encs]
+ if COND_GAN:
+ out_labels[i:i+n_encs] = label_batch[:n_encs]
+
+ out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
+ vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch)
+ print('Saved samples for imgs: {}-{}.'.format(i,i+n_encs))
diff --git a/inversion/segmentation.py b/inversion/segmentation.py
new file mode 100644
index 0000000..42f2f9e
--- /dev/null
+++ b/inversion/segmentation.py
@@ -0,0 +1,191 @@
+# ------------------------------------------------------------------------------
+# Cluster the attention map of inverted images for unsupervised segmentation.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import scipy
+import scipy.cluster.hierarchy
+from sklearn.cluster import AgglomerativeClustering
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+if len(sys.argv) < 2:
+ sys.exit('Must provide a configuration file.')
+
+params = params.Params(sys.argv[1])
+params.batch_size = 1
+params.sample_size = 1
+
+# --------------------------
+# Global directories.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'attention'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+ os.makedirs(SAMPLES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One hot encoding for classes.
+def one_hot(values):
+ return np.eye(N_CLASS)[values]
+
+def segment_img(diss_matrix, n_clusters):
+ # Cluster the image based on the information from the attention map.
+ # Note: reads the global _gen_img computed in the main loop below.
+ clustering = AgglomerativeClustering(n_clusters=n_clusters,
+ affinity='precomputed', linkage='average')
+ clustering.fit(diss_matrix)
+ labels = clustering.labels_
+
+ # Upsample segmentation (from 64x64 to 128x128) and create an image where each
+ # segment has the average color of its members.
+ labels = np.broadcast_to(labels.reshape(64, 1, 64, 1), (64, 2, 64, 2))\
+ .reshape(128*128)
+ labels = np.eye(labels.max() + 1)[labels]
+ cluster_col = np.matmul(labels.T,
+ np.transpose(_gen_img, [0, 2, 3, 1]).reshape(128*128, 3))
+ cluster_count = labels.T.sum(axis=1).reshape(-1, 1)
+ labels_img = np.matmul(labels, cluster_col) / np.matmul(labels, cluster_count)
+ labels_img = np.transpose(labels_img, [1, 0]).reshape(1,3,128,128)
+ return vs.data2img(labels_img)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+ Z_DIM = input_info['z'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ N_CLASS = input_info['y'].get_shape().as_list()[1]
+ label = tf.get_variable(name='label', dtype=tf.float32,
+ shape=[BATCH_SIZE, N_CLASS])
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_in['y'] = label
+ gen_img = generator(gen_in, signature=gen_signature)
+else:
+ Z_DIM = input_info['default'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ if (params.generator_fixed_inputs):
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_img = generator(gen_in, signature=gen_signature)
+ else:
+ gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+ encoding = latent
+ ENC_SHAPE = [Z_DIM]
+else:
+ layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+ gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+ ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+ encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+ shape=[BATCH_SIZE,] + ENC_SHAPE)
+ tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Get attention map.
+att_map_name = 'module_apply_' + gen_signature + '/' + params.attention_map_layer
+att_map = tf.get_default_graph().get_tensor_by_name(att_map_name)
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Dataset.
+# --------------------------
+if params.out_dataset.endswith('.hdf5'):
+ in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+ sample_images = in_file['xtrain']
+ if COND_GAN:
+ sample_labels = in_file['ytrain']
+ sample_latents = in_file['latent']
+ sample_encodings = in_file['encoding']
+ NUM_IMGS = sample_images.shape[0] # number of images.
+ def sample_images_gen():
+ for i in range(int(NUM_IMGS / BATCH_SIZE)):
+ i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+ if COND_GAN:
+ label_batch = sample_labels[i_1:i_2]
+ else:
+ label_batch = np.zeros(BATCH_SIZE)
+ yield sample_images[i_1:i_2], label_batch, sample_latents[i_1:i_2],\
+ sample_encodings[i_1:i_2]
+ image_gen = sample_images_gen()
+else:
+ sys.exit('Unknown dataset {}.'.format(params.out_dataset))
+
+NUM_IMGS -= NUM_IMGS % BATCH_SIZE
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Export attention map for reconstructed images.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch, lat_batch, enc_batch in image_gen:
+
+ # Set target label.
+ if COND_GAN:
+ sess.run(label.assign(one_hot(label_batch)))
+
+ # Initialize encodings.
+ sess.run(latent.assign(lat_batch))
+ sess.run(encoding.assign(enc_batch))
+
+ # Get attention map.
+ _att_map, _gen_img = sess.run([att_map, gen_img])
+
+ # Upsampling (from 32x32 to 64x64).
+ _att_map = np.broadcast_to(_att_map.reshape(64,64,32,1,32,1),
+ (64,64,32,2,32,2)).reshape(4096,4096)
+
+ # Define dissimilarity matrix.
+ dissimilarity = 1.0 - (_att_map + _att_map.T) / 2.0
+ dissimilarity *= (np.ones((4096,4096)) - np.identity(4096))
+
+ # Segment the image with different number of clusters.
+ seg_img_8 = segment_img(dissimilarity, 8)
+ seg_img_20 = segment_img(dissimilarity, 20)
+ seg_img_40 = segment_img(dissimilarity, 40)
+
+ # Save segmentation.
+ out_batch_1 = vs.interleave(image_batch, seg_img_20)
+ out_batch_2 = vs.interleave(seg_img_8, seg_img_40)
+ out_batch = vs.interleave(out_batch_1, out_batch_2)
+ out_batch = vs.seq_transform(out_batch)
+ vs.save_image('{}/segmented_img_{}.png'.format(SAMPLES_DIR, it), out_batch)
+
+ it += 1
+
+sess.close()
diff --git a/inversion/visualize.py b/inversion/visualize.py
new file mode 100644
index 0000000..07aea2d
--- /dev/null
+++ b/inversion/visualize.py
@@ -0,0 +1,88 @@
+# ------------------------------------------------------------------------------
+# Util functions to visualize images.
+# ------------------------------------------------------------------------------
+
+import numpy as np
+import cv2 as cv
+from PIL import Image
+
+def split(x):
+ assert type(x) == int
+ t = int(np.floor(np.sqrt(x)))
+ for a in range(t, 0, -1):
+ if x % a == 0:
+ return a, x // a  # Integer factors: a * (x // a) == x.
+
+def grid_transform(x):
+ n, c, h, w = x.shape
+ a, b = split(n)
+ x = np.transpose(x, [0, 2, 3, 1])
+ x = np.reshape(x, [int(a), int(b), int(h), int(w), int(c)])
+ x = np.transpose(x, [0, 2, 1, 3, 4])
+ x = np.reshape(x, [int(a * h), int(b * w), int(c)])
+ if x.shape[2] == 1:
+ x = np.squeeze(x, axis=2)
+ return x
+
+def seq_transform(x):
+ n, c, h, w = x.shape
+ x = np.transpose(x, [2, 0, 3, 1])
+ x = np.reshape(x, [h, n * w, c])
+ return x
+
+# Converts image pixels from range [-1, 1] to [0, 255].
+def data2img(data):
+ rescaled = np.divide(data + 1.0, 2.0) * 255.
+ rescaled = np.clip(rescaled, 0, 255)
+ return np.rint(rescaled).astype('uint8')
+
+def interleave(a, b):
+ res = np.empty([a.shape[0] + b.shape[0]] + list(a.shape[1:]), dtype=a.dtype)
+ res[0::2] = a
+ res[1::2] = b
+ return res
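+# e.g. interleave(targets, recons) -> [t0, r0, t1, r1, ...] along axis 0.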
+
+def save_image(filepath, img):
+ pilimg = Image.fromarray(img)
+ pilimg.save(filepath)
+
+def imread(filename):
+ img = cv.imread(filename, cv.IMREAD_UNCHANGED)
+ if img is not None:
+ if len(img.shape) > 2:
+ img = img[...,::-1]
+ return img
+
+def imconvert_float32(im):
+ im = np.float32(im)
+ im = (im / 255.0) * 2.0 - 1  # Map [0, 255] to [-1, 1].
+ return im
+
+def load_image(opt_fp_in, opt_dims=128):
+ target_im = imread(opt_fp_in)
+ w = target_im.shape[1]
+ h = target_im.shape[0]
+ if w <= h:
+ scale = opt_dims / w
+ else:
+ scale = opt_dims / h
+ target_im = cv.resize(target_im,(0,0), fx=scale, fy=scale)
+ w = target_im.shape[1]
+ h = target_im.shape[0]
+
+ x0 = 0
+ x1 = opt_dims
+ y0 = 0
+ y1 = opt_dims
+ if w > opt_dims:
+ x0 += int((w - opt_dims) / 2)
+ x1 += x0
+ if h > opt_dims:
+ y0 += int((h - opt_dims) / 2)
+ y1 += y0
+ phi_target = imconvert_float32(target_im)
+ phi_target = phi_target[y0:y1,x0:x1]
+ if phi_target.shape[2] == 4:
+ phi_target = phi_target[:,:,1:4]
+ phi_target = np.expand_dims(phi_target, 0)
+ return phi_target