| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-08 21:43:30 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-08 21:43:30 +0100 |
| commit | fb70ab05768fa4a54358dc1f304b68bc7aff6dae | |
| tree | 6ba4c805ce37b5b8827b08946f0b22f639fa3e14 | |
| parent | 326db345db13b1ab3a76406644654cb78b4d1b8d | |
inversion json files
| mode | file | insertions |
|---|---|---|
| -rwxr-xr-x | .gitignore | 4 |
| -rw-r--r-- | inversion/LICENSE | 21 |
| -rw-r--r-- | inversion/README.md | 43 |
| -rw-r--r-- | inversion/image_sample.py | 295 |
| -rw-r--r-- | inversion/interpolation.py | 190 |
| -rw-r--r-- | inversion/inversion.py | 477 |
| -rw-r--r-- | inversion/params.py | 27 |
| -rw-r--r-- | inversion/params_dense-512.json | 40 |
| -rw-r--r-- | inversion/params_dense.json | 40 |
| -rw-r--r-- | inversion/params_latent-512.json | 40 |
| -rw-r--r-- | inversion/params_latent.json | 40 |
| -rw-r--r-- | inversion/random_sample-512.json | 12 |
| -rw-r--r-- | inversion/random_sample.json | 12 |
| -rw-r--r-- | inversion/random_sample.py | 144 |
| -rw-r--r-- | inversion/segmentation.py | 191 |
| -rw-r--r-- | inversion/visualize.py | 88 |
16 files changed, 1664 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
@@ -64,3 +64,7 @@ data_store/imagenet/imagenet_images/
 *.zip
+*.hdf5
+
+events.out.*
+
diff --git a/inversion/LICENSE b/inversion/LICENSE
new file mode 100644
index 0000000..b868d12
--- /dev/null
+++ b/inversion/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 MarcosPividori
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/inversion/README.md b/inversion/README.md
new file mode 100644
index 0000000..3be7b8d
--- /dev/null
+++ b/inversion/README.md
@@ -0,0 +1,43 @@
+Exploiting GAN Internal Capacity for High-Quality Reconstruction of Natural Images
+==================================================================================
+
+Code for reproducing the experiments in ["Exploiting GAN Internal Capacity for High-Quality Reconstruction of Natural Images"](https://arxiv.org/abs/1911.05630).
+
+This directory contains the source code to invert the BigGAN generator at
+128x128 resolution. Requires TensorFlow.
+
+## Generation of Random Samples:
+Generate 1000 random samples from the BigGAN generator:
+```console
+ $> python random_sample.py random_sample.json
+```
+
+## Inversion of the Generator:
+The optimization is split into two steps, following the paper.
+First step, inversion to the latent space:
+```console
+ $> python inversion.py params_latent.json
+```
+
+Second step, inversion to the dense layer:
+```console
+ $> python inversion.py params_dense.json
+```
+
+## Interpolation:
+Generate interpolations between the inverted images and generated images:
+```console
+ $> python interpolation.py params_dense.json
+```
+
+## Segmentation:
+Segment inverted images by clustering the attention map:
+```console
+ $> python segmentation.py params_dense.json
+```
+
+Note: to replicate the experiments on real images from ImageNet, first
+an hdf5 file must be created with random images from the dataset, similar to
+the procedure in "random_sample.py". Then, the two optimization steps must be
+executed (modify the "dataset" parameter in params_latent.json to point at the
+custom dataset).
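The README note above leaves the hdf5-creation step to the reader. A minimal sketch (not part of this commit) that produces the `xtrain`/`ytrain` layout the inversion scripts expect, uint8 channels_first images and uint32 labels, using the `visualize` helpers added further down; the directory name and labels are placeholders:

```python
import h5py
import numpy as np
from glob import glob
import visualize as vs  # helper module added in this commit

# Placeholder inputs: any directory of ImageNet crops plus integer class ids.
paths = sorted(glob('imagenet_images/*.jpg'))   # hypothetical directory
labels = np.zeros(len(paths), dtype='uint32')   # real class ids in practice

with h5py.File('inverses/dataset.hdf5', 'w') as f:
    xs = f.create_dataset('xtrain', (len(paths), 3, 128, 128), dtype='uint8')
    ys = f.create_dataset('ytrain', (len(paths),), dtype='uint32')
    for i, path in enumerate(paths):
        img = vs.data2img(vs.load_image(path, 128))[0]  # (128, 128, 3) uint8
        xs[i] = np.transpose(img, [2, 0, 1])            # channels_first
        ys[i] = labels[i]
```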
diff --git a/inversion/image_sample.py b/inversion/image_sample.py
new file mode 100644
index 0000000..83622a1
--- /dev/null
+++ b/inversion/image_sample.py
@@ -0,0 +1,295 @@
+# ------------------------------------------------------------------------------
+# Invert a directory of images with the BigGAN generator and save the targets
+# and recovered labels to an hdf5 file.
+# ------------------------------------------------------------------------------
+
+import argparse
+import h5py
+import numpy as np
+import os
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import tensorflow_probability as tfp  # used by the likelihood loss below
+import time
+import visualize as vs
+from glob import glob
+
+# --------------------------
+# Command line arguments.
+# --------------------------
+# Note: this script takes flags rather than a json configuration file. The
+# fields referenced later (inv_it, log_activation_layer,
+# feature_extractor_output) were read but never defined in the original
+# listing; they default here to the values used in params_latent.json.
+parser = argparse.ArgumentParser(description='Invert a directory of images.')
+parser.add_argument('--input_dir', required=True, help='Input directory of images')
+parser.add_argument('--tag', default=str(int(time.time())), help='Tag this run')
+parser.add_argument('--inv_it', type=int, default=1000,
+                    help='Optimization iterations per batch')
+parser.add_argument('--log_activation_layer', default='Generator_2/GBlock/Relu:0',
+                    help='Layer whose activation rate is logged')
+parser.add_argument('--feature_extractor_output', default='InceptionV3/Mixed_7a',
+                    help='Output name from the feature extractor')
+params = parser.parse_args()
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# General parameters.
+BATCH_SIZE = 20
+SAMPLE_SIZE = 20
+assert SAMPLE_SIZE <= BATCH_SIZE
+
+INVERSION_ITERATIONS = params.inv_it
+
+# --------------------------
+# Global directories.
+# --------------------------
+DATASET_OUT = "{}_dataset.hdf5".format(params.tag)
+LOGS_DIR = './outputs/{}/logs'.format(params.tag)
+SAMPLES_DIR = './outputs/{}/samples'.format(params.tag)
+INVERSES_DIR = './outputs/{}/inverses'.format(params.tag)
+os.makedirs(LOGS_DIR, exist_ok=True)
+os.makedirs(SAMPLES_DIR, exist_ok=True)
+os.makedirs(INVERSES_DIR, exist_ok=True)
+
+# --------------------------
+# Logging.
+# --------------------------
+# (log_stats was called but never defined in this script; the helper below
+# mirrors the one in inversion.py.)
+summary_writer = tf.summary.FileWriter(LOGS_DIR)
+def log_stats(name, val, it):
+  summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
+  summary_writer.add_summary(summary, it)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module('https://tfhub.dev/deepmind/biggan-128/2')
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+
+Z_DIM = input_info['z'].get_shape().as_list()[1]
+latent = tf.get_variable(name='latent', dtype=tf.float32,
+                         shape=[BATCH_SIZE, Z_DIM])
+N_CLASS = input_info['y'].get_shape().as_list()[1]
+label = tf.get_variable(name='label', dtype=tf.float32,
+                        shape=[BATCH_SIZE, N_CLASS])
+gen_in = {}
+gen_in['truncation'] = 1.0
+gen_in['z'] = latent
+gen_in['y'] = label
+gen_img = generator(gen_in, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# This script always inverts to the latent space.
+encoding = latent
+ENC_SHAPE = [Z_DIM]
+
+# Step counter.
+inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
+
+# Define target image.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+target = tf.get_variable(name='target', dtype=tf.int32,
+                         shape=[BATCH_SIZE,] + IMG_SHAPE)
+target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1.  # Norm to [-1, 1].
+
+# Monitor relu's activation.
+gen_scope = 'module_apply_' + gen_signature + '/'
+activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()
+    .get_tensor_by_name(gen_scope + params.log_activation_layer))
+# --------------------------
+# Reconstruction losses.
+# --------------------------
+# Mse loss for image comparison.
+pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+mse_loss = tf.reduce_mean(pix_square_diff)
+img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3])
+
+# Use custom features for image comparison.
+feature_extractor = hub.Module(
+    "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
+
+# Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
+gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
+target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
+
+# Convert images to appropriate size for feature extraction.
+height, width = hub.get_expected_image_size(feature_extractor)
+gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
+target_img_1 = tf.image.resize_images(target_img_1, [height, width])
+
+gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True,
+    signature='image_feature_vector')[params.feature_extractor_output]
+target_feat = feature_extractor(dict(images=target_img_1), as_dict=True,
+    signature='image_feature_vector')[params.feature_extractor_output]
+feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+    [BATCH_SIZE, -1])
+feat_loss = tf.reduce_mean(feat_square_diff)
+img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)
+
+# --------------------------
+# Regularization losses.
+# --------------------------
+norm_dist = tfp.distributions.Normal(0.0, 1.0)
+likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
+mode_log_prob = norm_dist.log_prob(0.0)
+likeli_loss += mode_log_prob
+
+# Per image reconstruction error.
+img_rec_err = 1.0 * img_mse_err + 1.0 * img_feat_err
+
+# Batch reconstruction error.
+rec_loss = 1.0 * mse_loss + 1.0 * feat_loss
+
+# Total inversion loss.
+inv_loss = rec_loss + 0.1 * likeli_loss
+
+# --------------------------
+# Optimizer.
+# --------------------------
+lrate = tf.train.exponential_decay(0.1, inv_step,
+    INVERSION_ITERATIONS / 2, 0.1, staircase=True)
+trained_params = [latent, label]
+optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
+    global_step=inv_step)
+reinit_optimizer = tf.variables_initializer(optimizer.variables())
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+  return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+def label_init(shape=[BATCH_SIZE, N_CLASS]):
+  return np.random.uniform(low=0.00, high=0.01, size=shape)
+
+# --------------------------
+# Dataset.
+# --------------------------
+paths = glob(os.path.join(params.input_dir, '*.jpg')) + \
+        glob(os.path.join(params.input_dir, '*.jpeg')) + \
+        glob(os.path.join(params.input_dir, '*.png'))
+# Load each image as a uint8 channels_first array, matching `target`.
+sample_images = [np.transpose(vs.data2img(vs.load_image(fn, 128))[0], [2, 0, 1])
+                 for fn in sorted(paths)]
+ACTUAL_NUM_IMGS = len(sample_images)  # number of images to be inverted
+print("Number of images: {}".format(ACTUAL_NUM_IMGS))
+NUM_IMGS = ACTUAL_NUM_IMGS
+
+# Pad the image list so it divides evenly into batches.
+while NUM_IMGS % BATCH_SIZE != 0:
+  sample_images.append(sample_images[-1])
+  NUM_IMGS += 1
+sample_images = np.array(sample_images)
+
+def sample_images_gen():
+  for i in range(NUM_IMGS // BATCH_SIZE):
+    i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+    yield sample_images[i_1:i_2]
+image_gen = sample_images_gen()
+
+assert NUM_IMGS % BATCH_SIZE == 0
+# --------------------------
+# Inversion.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Output file.
+out_file = h5py.File(os.path.join(INVERSES_DIR, DATASET_OUT), 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='uint8')
+out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+
+# Gradient descent w.r.t. generator's inputs.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch in image_gen:
+  # Set target.
+  sess.run(target.assign(image_batch))
+
+  # Start with a near-uniform random label.
+  sess.run(label.assign(label_init()))
+
+  # Start with a random latent vector.
+  sess.run(latent.assign(noise_sampler()))
+
+  # Init optimizer.
+  sess.run(inv_step.assign(0))
+  sess.run(reinit_optimizer)
+
+  # Main optimization loop.
+  for _ in range(params.inv_it):
+    _inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss,\
+        _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss,
+                              rec_loss, likeli_loss, lrate, inv_train_op])
+
+    # Every 100 iterations save logs with training information.
+    if it % 100 == 0:
+      # Log losses.
+      etime = time.time() - start_time
+      print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
+            'feat [{:.4f}] rec [{:.4f}] likeli [{:.4f}] lr [{:.4f}]'.format(
+                it, etime, _inv_loss, _mse_loss, _feat_loss, _rec_loss,
+                _likeli_loss, _lrate))
+
+      _act_rate = sess.run(activation_rate)
+      print('activation_rate={:.4f}'.format(_act_rate))
+      log_stats('activation rate', _act_rate, it)
+
+      sys.stdout.flush()
+
+      # Log tensorboard's statistics.
+      log_stats('total loss', _inv_loss, it)
+      log_stats('mse loss', _mse_loss, it)
+      log_stats('feat loss', _feat_loss, it)
+      log_stats('rec loss', _rec_loss, it)
+      log_stats('likeli loss', _likeli_loss, it)
+      log_stats('out pos', out_pos, it)
+      log_stats('lrate', _lrate, it)
+      summary_writer.flush()
+
+      # Save target images and reconstructions.
+      gen_images = sess.run(gen_img)
+      inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+          vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+      inv_batch = vs.grid_transform(inv_batch)
+      vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+
+    it += 1
+
+  # Save the targets and the recovered labels (argmax of the soft labels).
+  gen_images = vs.data2img(sess.run(gen_img))
+  label_batch = sess.run(label)
+
+  out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
+  out_labels[out_pos:out_pos+BATCH_SIZE] = np.argmax(label_batch, axis=1)
+
+  out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
+  vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, out_pos), out_batch)
+  print('Saved samples for imgs: {}-{}.'.format(out_pos, out_pos+BATCH_SIZE))
+  out_pos += BATCH_SIZE
+
+out_file.close()
+sess.close()
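Unlike the other scripts, image_sample.py is driven by command-line flags rather than a json config; a typical invocation (directory name illustrative) is:

```console
 $> python image_sample.py --input_dir ./my_images --tag demo
```

Progress grids, samples, and the `demo_dataset.hdf5` output then land under `./outputs/demo/`.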
diff --git a/inversion/interpolation.py b/inversion/interpolation.py
new file mode 100644
index 0000000..7133337
--- /dev/null
+++ b/inversion/interpolation.py
@@ -0,0 +1,190 @@
+# ------------------------------------------------------------------------------
+# Linear interpolation between inverted images and generated images.
+# ------------------------------------------------------------------------------
+
+import functools
+import h5py
+import itertools
+import numpy as np
+import os
+import pickle
+import params
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+if len(sys.argv) < 2:
+  sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Global variables.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'interpolation'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+  os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+  os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+def interpolate(A, B, num_interps):
+  alphas = np.linspace(0, 1., num_interps)
+  if A.shape != B.shape:
+    raise ValueError('A and B must have the same shape to interpolate.')
+  return np.array([(1-a)*A + a*B for a in alphas])
+
+# One hot encoding for classes.
+def one_hot(values):
+  return np.eye(N_CLASS)[values]
+
+# Random sampler for classes. (np.random.random_integers is deprecated;
+# randint's upper bound is exclusive, so this draws the same range.)
+def label_sampler(size=[BATCH_SIZE]):
+  return np.random.randint(low=0, high=N_CLASS, size=size)
+
+def label_hot_sampler(size=[BATCH_SIZE]):
+  return one_hot(label_sampler(size=size))
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+  Z_DIM = input_info['z'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  N_CLASS = input_info['y'].get_shape().as_list()[1]
+  label = tf.get_variable(name='label', dtype=tf.float32,
+                          shape=[BATCH_SIZE, N_CLASS])
+  gen_in = dict(params.generator_fixed_inputs)
+  gen_in['z'] = latent
+  gen_in['y'] = label
+  gen_img = generator(gen_in, signature=gen_signature)
+else:
+  Z_DIM = input_info['default'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  if params.generator_fixed_inputs:
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_img = generator(gen_in, signature=gen_signature)
+  else:
+    gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+  encoding = latent
+  ENC_SHAPE = [Z_DIM]
+else:
+  layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+  gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+  ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+  encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+                             shape=[BATCH_SIZE,] + ENC_SHAPE)
+  tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+  return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+# --------------------------
+# Dataset.
+# --------------------------
+in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+in_images = in_file['xtrain']
+if COND_GAN:
+  in_labels = in_file['ytrain']
+in_encoding = in_file['encoding']
+in_latent = in_file['latent']
+NUM_IMGS = in_images.shape[0]  # number of images.
+# --------------------------
+# Interpolation.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+for i in range(0, NUM_IMGS, BATCH_SIZE):
+  # Set label.
+  if COND_GAN:
+    sess.run(label.assign(one_hot(in_labels[i:i+BATCH_SIZE])))
+
+  # Linear interpolation between G_1(z*) and G_1(z*)+delta*.
+  sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
+  out_batch = np.ndarray(shape=[8,BATCH_SIZE]+IMG_SHAPE, dtype='uint8')
+  out_batch[0] = in_images[i:i+BATCH_SIZE]
+  sess.run(latent.assign(in_latent[i:i+BATCH_SIZE]))
+  sample_enc_2 = sess.run(gen_encoding)
+  sample_enc = interpolate(sample_enc_1, sample_enc_2, 7)
+  for j in range(0,7):
+    sess.run(encoding.assign(sample_enc[j]))
+    gen_images = sess.run(gen_img)
+    gen_images = vs.data2img(gen_images)
+    out_batch[j+1] = gen_images
+
+  out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
+  for k in range(BATCH_SIZE):
+    out_batch_k = vs.seq_transform(out_batch[k])
+    # Add white padding after the first (target) frame.
+    pad = 20
+    out_batch_kk = np.ndarray(shape=[IMG_SHAPE[1], IMG_SHAPE[1]*8+pad, 3],
+                              dtype='uint8')
+    out_batch_kk[:,:IMG_SHAPE[1],:] = out_batch_k[:,:IMG_SHAPE[1],:]
+    out_batch_kk[:,IMG_SHAPE[1]:IMG_SHAPE[1]+pad,:] = 255
+    out_batch_kk[:,IMG_SHAPE[1]+pad:,:] = out_batch_k[:,IMG_SHAPE[1]:,:]
+
+    vs.save_image('{}/interpolation_delta_{}.png'.format(SAMPLES_DIR, i+k),
+                  out_batch_kk)
+    print('Saved delta interpolation for img: {}.'.format(i+k))
+
+  # Linear interpolation between G_1(z_random) and G_1(z*)+delta*.
+  sample_enc_1 = in_encoding[i:i+BATCH_SIZE]
+  sample_z_1 = in_latent[i:i+BATCH_SIZE]
+  out_batch = np.ndarray(shape=[8*8,BATCH_SIZE]+IMG_SHAPE, dtype='uint8')
+  for k in range(8):
+    sample_z_2 = noise_sampler()
+    sess.run(latent.assign(sample_z_2))
+    sample_enc_2 = sess.run(gen_encoding)
+    sample_z = interpolate(sample_z_1, sample_z_2, 8)
+    sample_enc = interpolate(sample_enc_1, sample_enc_2, 8)
+    for j in range(8):
+      sess.run(latent.assign(sample_z[j]))
+      sess.run(encoding.assign(sample_enc[j]))
+      gen_images = sess.run(gen_img)
+      gen_images = vs.data2img(gen_images)
+      out_batch[k*8+j] = gen_images
+
+  out_batch = np.transpose(out_batch, [1, 0, 2, 3, 4])
+  for k in range(BATCH_SIZE):
+    out_batch_k = vs.grid_transform(out_batch[k])
+    vs.save_image('{}/interpolation_rand_{}.png'.format(SAMPLES_DIR, i+k),
+                  out_batch_k)
+    print('Saved rand interpolation for img: {}.'.format(i+k))
+
+sess.close()
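A quick sanity check of the `interpolate` helper used above: it returns `num_interps` frames along a new leading axis, endpoints included (illustrative values):

```python
import numpy as np

# Stand-in for the interpolate() helper defined in interpolation.py.
def interpolate(A, B, num_interps):
    alphas = np.linspace(0, 1., num_interps)
    return np.array([(1 - a) * A + a * B for a in alphas])

frames = interpolate(np.zeros(3), np.ones(3), 8)
print(frames.shape)  # (8, 3); frames[0] == A and frames[7] == B
```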
diff --git a/inversion/inversion.py b/inversion/inversion.py
new file mode 100644
index 0000000..04c033b
--- /dev/null
+++ b/inversion/inversion.py
@@ -0,0 +1,477 @@
+# ------------------------------------------------------------------------------
+# Invert the generator by gradient descent w.r.t. the generator's inputs, for
+# the latent space or any intermediate layer.
+# ------------------------------------------------------------------------------
+
+import glob
+import h5py
+import itertools
+import numpy as np
+import os
+import params
+import PIL
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# Expected parameters:
+# generator_path: path to generator module.
+# generator_fixed_inputs: dictionary of fixed generator's input parameters.
+# dataset: name of the dataset (hdf5 file).
+# out_dataset: name for the output inverted dataset (hdf5 file).
+# General parameters:
+# batch_size: number of images inverted at the same time.
+# inv_it: number of iterations to invert an image.
+# inv_layer: 'latent' or name of the tensor of the custom layer to be inverted.
+# lr: learning rate.
+# decay_lr: exponential decay on the learning rate.
+# decay_n: number of exponential decays on the learning rate.
+# custom_grad_relu: replace relus with custom gradient.
+# Logging:
+# sample_size: number of images included in sampled images.
+# save_progress: whether to save intermediate images during optimization.
+# log_z_norm: log the norm of different sections of z.
+# log_activation_layer: log the percentage of active neurons in this layer.
+# Losses:
+# mse: use the mean squared error on pixels for image comparison.
+# features: use features extracted by a feature extractor for image comparison.
+# feature_extractor_path: path to feature extractor module.
+# feature_extractor_output: output name from feature extractor.
+# likeli_loss: regularization loss on the log likelihood of encodings.
+# norm_loss: regularization loss on the norm of encodings.
+# dist_loss: whether to include a loss on the dist between g1(z) and enc.
+# lambda_mse: coefficient for mse loss.
+# lambda_feat: coefficient for features loss.
+# lambda_reg: coefficient for regularization loss on latent.
+# lambda_dist: coefficient for l1 regularization on delta.
+# Latent:
+# clipping: whether to clip encoding values after every update.
+# stochastic_clipping: whether to consider stochastic clipping.
+# clip: clipping bound.
+# pre_trained_latent: load a pre-trained fixed latent.
+# fixed_z: do not train the latent vector.
+# Initialization:
+# init_gen_dist: initialize encodings from the generated distribution.
+# init_lo: init min value.
+# init_hi: init max value.
+if len(sys.argv) < 2:
+  sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Global directories.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+LOGS_DIR = 'logs'
+SAMPLES_DIR = 'samples'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(LOGS_DIR):
+  os.makedirs(LOGS_DIR)
+if not os.path.exists(SAMPLES_DIR):
+  os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+  os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One hot encoding for classes.
+def one_hot(values):
+  return np.eye(N_CLASS)[values]
+
+# --------------------------
+# Logging.
+# --------------------------
+summary_writer = tf.summary.FileWriter(LOGS_DIR)
+def log_stats(name, val, it):
+  summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
+  summary_writer.add_summary(summary, it)
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+  Z_DIM = input_info['z'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  N_CLASS = input_info['y'].get_shape().as_list()[1]
+  label = tf.get_variable(name='label', dtype=tf.float32,
+                          shape=[BATCH_SIZE, N_CLASS])
+  gen_in = dict(params.generator_fixed_inputs)
+  gen_in['z'] = latent
+  gen_in['y'] = label
+  gen_img = generator(gen_in, signature=gen_signature)
+else:
+  Z_DIM = input_info['default'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  if params.generator_fixed_inputs:
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_img = generator(gen_in, signature=gen_signature)
+  else:
+    gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer: replace the tensor at inv_layer with a
+# trainable variable so the optimizer can search in that space directly.
+if params.inv_layer == 'latent':
+  encoding = latent
+  ENC_SHAPE = [Z_DIM]
+else:
+  layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+  gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+  ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+  encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+                             shape=[BATCH_SIZE,] + ENC_SHAPE)
+  tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Step counter.
+inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
+
+# Define target image.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+target = tf.get_variable(name='target', dtype=tf.int32,
+                         shape=[BATCH_SIZE,] + IMG_SHAPE)
+target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1.  # Norm to [-1, 1].
+
+# Custom Gradient for Relus: let a decaying fraction of the negative gradient
+# through where the relu is inactive, so dead units can recover.
+if params.custom_grad_relu:
+  grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5,
+                                           0.1, staircase=False)
+  @tf.custom_gradient
+  def relu_custom_grad(x):
+    def grad(dy):
+      return tf.where(x >= 0, dy,
+                      grad_lambda*tf.where(dy < 0, dy, tf.zeros_like(dy)))
+    return tf.nn.relu(x), grad
+
+  gen_scope = 'module_apply_' + gen_signature + '/'
+  for op in tf.get_default_graph().get_operations():
+    if 'Relu' in op.name and gen_scope in op.name:
+      assert len(op.inputs) == 1
+      assert len(op.outputs) == 1
+      new_out = relu_custom_grad(op.inputs[0])
+      tf.contrib.graph_editor.swap_ts(op.outputs[0], new_out)
+
+# Operations to clip the values of the encodings.
+if params.clipping or params.stochastic_clipping:
+  assert params.clip >= 0
+  if params.stochastic_clipping:
+    # Stochastic clipping: out-of-range entries are resampled uniformly
+    # within [-clip, clip] instead of being saturated at the bound.
+    new_enc = tf.where(tf.abs(latent) >= params.clip,
+        tf.random.uniform([BATCH_SIZE, Z_DIM], minval=-params.clip,
+                          maxval=params.clip), latent)
+  else:
+    new_enc = tf.clip_by_value(latent, -params.clip, params.clip)
+  clip_latent = tf.assign(latent, new_enc)
+
+# Monitor relu's activation.
+if params.log_activation_layer:
+  gen_scope = 'module_apply_' + gen_signature + '/'
+  activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()
+      .get_tensor_by_name(gen_scope + params.log_activation_layer))
+# --------------------------
+# Reconstruction losses.
+# --------------------------
+# Mse loss for image comparison.
+if params.mse:
+  pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+  mse_loss = tf.reduce_mean(pix_square_diff)
+  img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3])
+else:
+  mse_loss = tf.constant(0.0)
+  img_mse_err = tf.constant(0.0)
+
+# Use custom features for image comparison.
+if params.features:
+  feature_extractor = hub.Module(str(params.feature_extractor_path))
+
+  # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
+  gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
+  target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
+
+  # Convert images to appropriate size for feature extraction.
+  height, width = hub.get_expected_image_size(feature_extractor)
+  gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
+  target_img_1 = tf.image.resize_images(target_img_1, [height, width])
+
+  gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True,
+      signature='image_feature_vector')[params.feature_extractor_output]
+  target_feat = feature_extractor(dict(images=target_img_1), as_dict=True,
+      signature='image_feature_vector')[params.feature_extractor_output]
+  feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+      [BATCH_SIZE, -1])
+  feat_loss = tf.reduce_mean(feat_square_diff)
+  img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)
+else:
+  feat_loss = tf.constant(0.0)
+  img_feat_err = tf.constant(0.0)
+
+# --------------------------
+# Regularization losses.
+# --------------------------
+# Loss on the norm of the encoding: penalize 20-dim sections of z whose
+# squared norm exceeds the mode of the Chi2(20) distribution.
+if params.norm_loss:
+  dim = 20
+  chi2_dist = tfp.distributions.Chi2(dim)
+  mode = dim - 2
+  mode_log_prob = chi2_dist.log_prob(mode)
+  norm_loss = 0.0
+  for i in range(int(Z_DIM / dim)):
+    squared_l2 = tf.reduce_sum(tf.square(latent[:,i*dim:(i+1)*dim]), axis=1)
+    over_mode = tf.nn.relu(squared_l2 - mode)
+    norm_loss -= tf.reduce_mean(chi2_dist.log_prob(mode + over_mode))
+    norm_loss += mode_log_prob
+else:
+  norm_loss = tf.constant(0.0)
+
+# Loss on the likelihood of the encoding (offset so its minimum is zero).
+if params.likeli_loss:
+  norm_dist = tfp.distributions.Normal(0.0, 1.0)
+  likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
+  mode_log_prob = norm_dist.log_prob(0.0)
+  likeli_loss += mode_log_prob
+else:
+  likeli_loss = tf.constant(0.0)
+
+# Regularization loss.
+reg_loss = norm_loss + likeli_loss
+
+# Loss on the l1 distance between gen_encoding and inverted encoding.
+if params.dist_loss:
+  dist_loss = tf.reduce_mean(tf.abs(encoding - gen_encoding))
+else:
+  dist_loss = tf.constant(0.0)
+
+# Per image reconstruction error.
+img_rec_err = params.lambda_mse * img_mse_err\
+    + params.lambda_feat * img_feat_err
+
+# Batch reconstruction error.
+rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss
+
+# Total inversion loss.
+inv_loss = rec_loss + params.lambda_reg * reg_loss\
+    + params.lambda_dist * dist_loss
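Spelled out, the objective minimized below can be restated (an editorial summary of the code above, with the lambda coefficients supplied by the json configs) as:

$$
\mathcal{L}_{\mathrm{inv}} = \lambda_{\mathrm{mse}}\,\mathrm{mean}\!\left[(x - G(e))^2\right] + \lambda_{\mathrm{feat}}\,\mathrm{mean}\!\left[(\phi(x) - \phi(G(e)))^2\right] + \lambda_{\mathrm{reg}}\,R(z) + \lambda_{\mathrm{dist}}\,\mathrm{mean}\!\left[\,|e - G_1(z)|\,\right]
$$

where $x$ is the target (pixel differences are halved to the $[0,1]$ scale before squaring), $e$ is the free encoding at the inverted layer, $G_1(z)$ is the generator's own activation at that layer, $\phi$ is the Inception feature extractor, and $R(z)$ is the norm/likelihood prior on the latent. With params_latent.json only the first three terms are active; params_dense.json turns the prior off and adds the distance term with $\lambda_{\mathrm{dist}} = 10$.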
+# --------------------------
+# Optimizer.
+# --------------------------
+if params.decay_lr:
+  lrate = tf.train.exponential_decay(params.lr, inv_step,
+      params.inv_it / params.decay_n, 0.1, staircase=True)
+else:
+  lrate = tf.constant(params.lr)
+trained_params = [encoding] if params.fixed_z else [latent, encoding]
+optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
+    global_step=inv_step)
+reinit_optimizer = tf.variables_initializer(optimizer.variables())
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+  return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+def small_init(shape=[BATCH_SIZE, Z_DIM]):
+  return np.random.uniform(low=params.init_lo, high=params.init_hi, size=shape)
+
+# --------------------------
+# Dataset.
+# --------------------------
+if params.dataset.endswith('.hdf5'):
+  in_file = h5py.File(params.dataset, 'r')
+  sample_images = in_file['xtrain']
+  if COND_GAN:
+    sample_labels = in_file['ytrain']
+  NUM_IMGS = sample_images.shape[0]  # number of images to be inverted.
+  print("Number of images: {}".format(NUM_IMGS))
+  def sample_images_gen():
+    for i in range(int(NUM_IMGS / BATCH_SIZE)):
+      i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+      if COND_GAN:
+        yield sample_images[i_1:i_2], sample_labels[i_1:i_2]
+      else:
+        yield sample_images[i_1:i_2], np.zeros(BATCH_SIZE)
+  image_gen = sample_images_gen()
+  if 'latent' in in_file:
+    sample_latents = in_file['latent']
+    def sample_latent_gen():
+      for i in range(int(NUM_IMGS / BATCH_SIZE)):
+        i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+        yield sample_latents[i_1:i_2]
+    latent_gen = sample_latent_gen()
+  assert(NUM_IMGS % BATCH_SIZE == 0)
+else:
+  sys.exit('Unknown dataset {}.'.format(params.dataset))
+
+NUM_IMGS -= NUM_IMGS % BATCH_SIZE
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Output file.
+out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
+                                     dtype='uint8')
+out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE)
+out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM])
+if COND_GAN:
+  out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+out_err = out_file.create_dataset('err', (NUM_IMGS,))
+
+# Gradient descent w.r.t. generator's inputs.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch in image_gen:
+
+  # Save target.
+  sess.run(target.assign(image_batch))
+  if COND_GAN:
+    sess.run(label.assign(one_hot(label_batch)))
+
+  # Initialize encodings.
+  if params.pre_trained_latent:
+    # (Py3 fix: generators are advanced with next(gen), not gen.next().)
+    sess.run(latent.assign(next(latent_gen)))
+    if params.inv_layer != 'latent':
+      sess.run(encoding.assign(gen_encoding))
+  elif params.init_gen_dist:
+    sess.run(latent.assign(noise_sampler()))
+    if params.inv_layer != 'latent':
+      sess.run(encoding.assign(gen_encoding))
+  else:
+    sess.run(latent.assign(small_init()))
+    if params.inv_layer != 'latent':
+      sess.run(encoding.assign(small_init(shape=[BATCH_SIZE,] + ENC_SHAPE)))
+
+  # Init optimizer.
+  sess.run(inv_step.assign(0))
+  sess.run(reinit_optimizer)
+  # Main optimization loop.
+  print("Total iterations: {}".format(params.inv_it))
+  for _ in range(params.inv_it):
+
+    _inv_loss, _mse_loss, _feat_loss, _rec_loss, _reg_loss, _dist_loss,\
+        _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss,
+                              rec_loss, reg_loss, dist_loss, lrate, inv_train_op])
+
+    if params.clipping or params.stochastic_clipping:
+      sess.run(clip_latent)
+
+    # Every 100 iterations save logs with training information.
+    if it % 100 == 99:
+      # Log losses.
+      etime = time.time() - start_time
+      print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
+            'feat [{:.4f}] rec [{:.4f}] reg [{:.4f}] dist [{:.4f}] '
+            'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
+                _feat_loss, _rec_loss, _reg_loss, _dist_loss, _lrate))
+
+      if params.log_z_norm:
+        _lat = sess.run(latent)
+        dim = 20 if Z_DIM == 120 else Z_DIM
+        for i in range(int(Z_DIM/dim)):
+          _subset = _lat[:,i*dim:(i+1)*dim]
+          print('section {:1d}: norm={:.4f} (exp={:.4f}) min={:.4f} max={:.4f}'
+                .format(i, np.mean(np.linalg.norm(_subset, axis=1)),
+                        np.sqrt(dim-2), np.min(_subset), np.max(_subset)))
+
+      if params.log_activation_layer:
+        _act_rate = sess.run(activation_rate)
+        print('activation_rate={:.4f}'.format(_act_rate))
+        log_stats('activation rate', _act_rate, it)
+
+      sys.stdout.flush()
+
+      # Log tensorboard's statistics.
+      log_stats('total loss', _inv_loss, it)
+      log_stats('mse loss', _mse_loss, it)
+      log_stats('feat loss', _feat_loss, it)
+      log_stats('rec loss', _rec_loss, it)
+      log_stats('reg loss', _reg_loss, it)
+      log_stats('dist loss', _dist_loss, it)
+      log_stats('out pos', out_pos, it)
+      log_stats('lrate', _lrate, it)
+      summary_writer.flush()
+
+      # Save target images and reconstructions.
+      if params.save_progress:
+        assert SAMPLE_SIZE <= BATCH_SIZE
+        gen_images = sess.run(gen_img)
+        inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+            vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+        inv_batch = vs.grid_transform(inv_batch)
+        vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+
+      # Save linear interpolation between the actual and generated encodings.
+      if params.dist_loss and it % 1000 == 999:
+        enc_batch, gen_enc = sess.run([encoding, gen_encoding])
+        for j in range(10):
+          custom_enc = gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0)
+          sess.run(encoding.assign(custom_enc))
+          gen_images = sess.run(gen_img)
+          inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+              vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+          inv_batch = vs.grid_transform(inv_batch)
+          vs.save_image('{}/progress_{}_lat_{}.png'.format(SAMPLES_DIR,it,j),
+                        inv_batch)
+        sess.run(encoding.assign(enc_batch))
+
+    # It counter.
+    it += 1
+
+  # Save samples of inverted images.
+  if SAMPLE_SIZE > 0:
+    assert SAMPLE_SIZE <= BATCH_SIZE
+    gen_images = sess.run(gen_img)
+    inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+        vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+    inv_batch = vs.grid_transform(inv_batch)
+    vs.save_image('{}/{}.png'.format(SAMPLES_DIR, out_pos), inv_batch)
+    print('Saved samples for out_pos: {}.'.format(out_pos))
+
+  # Save images that are ready.
+  latent_batch, enc_batch, rec_err_batch =\
+      sess.run([latent, encoding, img_rec_err])
+  out_lat[out_pos:out_pos+BATCH_SIZE] = latent_batch
+  out_enc[out_pos:out_pos+BATCH_SIZE] = enc_batch
+  out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
+  if COND_GAN:
+    out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
+  out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
+  out_pos += BATCH_SIZE
+
+print('Mean reconstruction error: {}'.format(np.mean(out_err)))
+print('Stdev reconstruction error: {}'.format(np.std(out_err)))
+print('End of inversion.')
+out_file.close()
+sess.close()
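The stochastic clipping option in inversion.py (the `clip_latent` op defined above) is easiest to see in isolation; a minimal numpy restatement of the same idea:

```python
import numpy as np

def stochastic_clip(z, clip=1.0, rng=np.random):
    """Resample out-of-range entries uniformly in [-clip, clip].

    Mirrors inversion.py's clip_latent op: unlike hard clipping, values
    never pile up on the boundary, so the optimizer keeps exploring.
    """
    resampled = rng.uniform(-clip, clip, size=z.shape)
    return np.where(np.abs(z) >= clip, resampled, z)

z = np.random.normal(size=(4, 120))
z_clipped = stochastic_clip(z, clip=1.0)
assert np.all(np.abs(z_clipped) <= 1.0)
```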
diff --git a/inversion/params.py b/inversion/params.py
new file mode 100644
index 0000000..dd9a358
--- /dev/null
+++ b/inversion/params.py
@@ -0,0 +1,27 @@
+# ------------------------------------------------------------------------------
+# Util class for hyperparams.
+# ------------------------------------------------------------------------------
+
+import json
+
+class Params():
+  """Class that loads hyperparameters from a json file."""
+
+  def __init__(self, json_path):
+    self.update(json_path)
+
+  def save(self, json_path):
+    """Saves parameters to json file."""
+    with open(json_path, 'w') as f:
+      json.dump(self.__dict__, f, indent=4)
+
+  def update(self, json_path):
+    """Loads parameters from json file."""
+    with open(json_path) as f:
+      params = json.load(f)
+    self.__dict__.update(params)
+
+  @property
+  def dict(self):
+    """Gives dict-like access to Params instance."""
+    return self.__dict__
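Every key in a json config becomes an attribute on the Params instance, which is how all the scripts above read their settings. For example (values from params_latent.json):

```python
from params import Params

params = Params('params_latent.json')
print(params.lr)         # 0.1
print(params.inv_layer)  # 'latent'

params.lr = 0.05         # attributes are plain instance fields
params.save('params_custom.json')
```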
"dataset": "inverses/dataset.hdf5", + "inv_layer": "latent", + "decay_lr": true, + "inv_it": 15000, + "generator_path": "https://tfhub.dev/deepmind/biggan-512/2", + "attention_map_layer": "Generator_2/attention/Softmax:0", + "pre_trained_latent": false, + "lambda_dist": 0.0, + "likeli_loss": true, + "init_hi": 0.001, + "lr": 0.1, + "norm_loss": false, + "generator_fixed_inputs": { + "truncation": 1.0 + }, + "log_z_norm": true, + "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1", + "mse": true, + "custom_grad_relu": false, + "random_label": false, + "lambda_feat": 1.0, + "init_gen_dist": false, + "log_activation_layer": "Generator_2/GBlock/Relu:0", + "batch_size": 4, + "fixed_z": false, + "feature_extractor_output": "InceptionV3/Mixed_7a", + "init_lo": -0.001, + "lambda_mse": 1.0, + "lambda_reg": 0.1, + "dist_loss": false, + "sample_size": 4, + "out_dataset": "dataset.encodings.hdf5", + "save_progress": true +} diff --git a/inversion/params_latent.json b/inversion/params_latent.json new file mode 100644 index 0000000..4642b52 --- /dev/null +++ b/inversion/params_latent.json @@ -0,0 +1,40 @@ +{ + "decay_n": 2, + "features": true, + "clip": 1.0, + "stochastic_clipping": false, + "clipping": false, + "dataset": "inverses/dataset.hdf5", + "inv_layer": "latent", + "decay_lr": true, + "inv_it": 15000, + "generator_path": "https://tfhub.dev/deepmind/biggan-128/2", + "attention_map_layer": "Generator_2/attention/Softmax:0", + "pre_trained_latent": false, + "lambda_dist": 0.0, + "likeli_loss": true, + "init_hi": 0.001, + "lr": 0.1, + "norm_loss": false, + "generator_fixed_inputs": { + "truncation": 1.0 + }, + "log_z_norm": true, + "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1", + "mse": true, + "custom_grad_relu": false, + "random_label": false, + "lambda_feat": 1.0, + "init_gen_dist": false, + "log_activation_layer": "Generator_2/GBlock/Relu:0", + "batch_size": 25, + "fixed_z": false, + "feature_extractor_output": "InceptionV3/Mixed_7a", + "init_lo": -0.001, + "lambda_mse": 1.0, + "lambda_reg": 0.1, + "dist_loss": false, + "sample_size": 25, + "out_dataset": "dataset.encodings.hdf5", + "save_progress": true +} diff --git a/inversion/random_sample-512.json b/inversion/random_sample-512.json new file mode 100644 index 0000000..4c5dbd0 --- /dev/null +++ b/inversion/random_sample-512.json @@ -0,0 +1,12 @@ +{ + "generator_path": "https://tfhub.dev/deepmind/biggan-512/2", + "batch_size": 16, + "sample_size": 16, + "custom_label": 0, + "dataset_out": "dataset.hdf5", + "generator_fixed_inputs": { + "truncation": 1.0 + }, + "num_imgs": 16, + "random_label": true +} diff --git a/inversion/random_sample.json b/inversion/random_sample.json new file mode 100644 index 0000000..c93aa1f --- /dev/null +++ b/inversion/random_sample.json @@ -0,0 +1,12 @@ +{ + "generator_path": "https://tfhub.dev/deepmind/biggan-128/2", + "batch_size": 20, + "sample_size": 20, + "custom_label": 0, + "dataset_out": "dataset.hdf5", + "generator_fixed_inputs": { + "truncation": 1.0 + }, + "num_imgs": 1000, + "random_label": true +} diff --git a/inversion/random_sample.py b/inversion/random_sample.py new file mode 100644 index 0000000..61cac9c --- /dev/null +++ b/inversion/random_sample.py @@ -0,0 +1,144 @@ +# ------------------------------------------------------------------------------ +# Generate random samples of the generator and save the images to a hdf5 file. 
diff --git a/inversion/random_sample.py b/inversion/random_sample.py
new file mode 100644
index 0000000..61cac9c
--- /dev/null
+++ b/inversion/random_sample.py
@@ -0,0 +1,144 @@
+# ------------------------------------------------------------------------------
+# Generate random samples of the generator and save the images to an hdf5 file.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# Expected parameters:
+# generator_path: path to generator module.
+# generator_fixed_inputs: dictionary of fixed generator's input parameters.
+# dataset_out: name for the output created dataset (hdf5 file).
+# General parameters:
+# batch_size: number of images generated at the same time.
+# random_label: choose random labels.
+# num_imgs: number of instances to generate.
+# custom_label: custom label to be fixed.
+# Logging:
+# sample_size: number of images included in sampled images.
+if len(sys.argv) < 2:
+  sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# General parameters.
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+assert SAMPLE_SIZE <= BATCH_SIZE
+NUM_IMGS = params.num_imgs
+
+# --------------------------
+# Global directories.
+# --------------------------
+SAMPLES_DIR = 'random_samples'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+  os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+  os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+def one_hot(values):
+  return np.eye(N_CLASS)[values]
+
+# (np.random.random_integers is deprecated; randint's upper bound is
+# exclusive, so this draws the same range.)
+def label_sampler(size=1):
+  return np.random.randint(low=0, high=N_CLASS, size=size)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+  Z_DIM = input_info['z'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  N_CLASS = input_info['y'].get_shape().as_list()[1]
+  label = tf.get_variable(name='label', dtype=tf.float32,
+                          shape=[BATCH_SIZE, N_CLASS])
+  gen_in = dict(params.generator_fixed_inputs)
+  gen_in['z'] = latent
+  gen_in['y'] = label
+  gen_img = generator(gen_in, signature=gen_signature)
+else:
+  Z_DIM = input_info['default'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  if params.generator_fixed_inputs:
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_img = generator(gen_in, signature=gen_signature)
+  else:
+    gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+  return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+# --------------------------
+# Generation.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+# Output file.
+out_file = h5py.File(os.path.join(INVERSES_DIR, params.dataset_out), 'w')
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
+                                     dtype='uint8')
+if COND_GAN:
+  out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+
+for i in range(0, NUM_IMGS, BATCH_SIZE):
+  n_encs = min(BATCH_SIZE, NUM_IMGS - i)
+
+  if COND_GAN:
+    if params.random_label:
+      label_batch = label_sampler(BATCH_SIZE)
+    else:
+      label_batch = [params.custom_label]*BATCH_SIZE
+    sess.run(label.assign(one_hot(label_batch)))
+
+  sess.run(latent.assign(noise_sampler()))
+
+  gen_images = sess.run(gen_img)
+  gen_images = vs.data2img(gen_images)
+
+  out_images[i:i+n_encs] = gen_images[:n_encs]
+  if COND_GAN:
+    out_labels[i:i+n_encs] = label_batch[:n_encs]
+
+  out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
+  vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch)
+  print('Saved samples for imgs: {}-{}.'.format(i, i+n_encs))
+
+out_file.close()
+sess.close()
diff --git a/inversion/segmentation.py b/inversion/segmentation.py
new file mode 100644
index 0000000..42f2f9e
--- /dev/null
+++ b/inversion/segmentation.py
@@ -0,0 +1,191 @@
+# ------------------------------------------------------------------------------
+# Cluster the attention map of inverted images for unsupervised segmentation.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import scipy
+import scipy.cluster.hierarchy
+from sklearn.cluster import AgglomerativeClustering
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+if len(sys.argv) < 2:
+  sys.exit('Must provide a configuration file.')
+
+params = params.Params(sys.argv[1])
+params.batch_size = 1
+params.sample_size = 1
+
+# --------------------------
+# Global directories.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'attention'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+  os.makedirs(SAMPLES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One hot encoding for classes.
+def one_hot(values):
+  return np.eye(N_CLASS)[values]
+
+def segment_img(diss_matrix, n_clusters):
+  # Cluster image based on the information from the attention map.
+  clustering = AgglomerativeClustering(n_clusters=n_clusters,
+      affinity='precomputed', linkage='average')
+  clustering.fit(diss_matrix)
+  labels = clustering.labels_
+
+  # Upsample segmentation (from 64x64 to 128x128) and create an image where
+  # each segment has the average color of its members. (_gen_img is the
+  # current generated image, set in the main loop below.)
+  labels = np.broadcast_to(labels.reshape(64, 1, 64, 1), (64, 2, 64, 2))\
+      .reshape(128*128)
+  labels = np.eye(labels.max() + 1)[labels]
+  cluster_col = np.matmul(labels.T,
+      np.transpose(_gen_img, [0, 2, 3, 1]).reshape(128*128, 3))
+  cluster_count = labels.T.sum(axis=1).reshape(-1, 1)
+  labels_img = np.matmul(labels, cluster_col) / np.matmul(labels, cluster_count)
+  labels_img = np.transpose(labels_img, [1, 0]).reshape(1,3,128,128)
+  return vs.data2img(labels_img)
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+  Z_DIM = input_info['z'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  N_CLASS = input_info['y'].get_shape().as_list()[1]
+  label = tf.get_variable(name='label', dtype=tf.float32,
+                          shape=[BATCH_SIZE, N_CLASS])
+  gen_in = dict(params.generator_fixed_inputs)
+  gen_in['z'] = latent
+  gen_in['y'] = label
+  gen_img = generator(gen_in, signature=gen_signature)
+else:
+  Z_DIM = input_info['default'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  if params.generator_fixed_inputs:
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_img = generator(gen_in, signature=gen_signature)
+  else:
+    gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+  encoding = latent
+  ENC_SHAPE = [Z_DIM]
+else:
+  layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+  gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+  ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+  encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+                             shape=[BATCH_SIZE,] + ENC_SHAPE)
+  tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Get attention map.
+att_map_name = 'module_apply_' + gen_signature + '/' + params.attention_map_layer
+att_map = tf.get_default_graph().get_tensor_by_name(att_map_name)
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Dataset.
+# --------------------------
+if params.out_dataset.endswith('.hdf5'):
+  in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+  sample_images = in_file['xtrain']
+  if COND_GAN:
+    sample_labels = in_file['ytrain']
+  sample_latents = in_file['latent']
+  sample_encodings = in_file['encoding']
+  NUM_IMGS = sample_images.shape[0]  # number of images.
+  def sample_images_gen():
+    # (Py3 fix: xrange -> range, with integer division.)
+    for i in range(NUM_IMGS // BATCH_SIZE):
+      i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+      if COND_GAN:
+        label_batch = sample_labels[i_1:i_2]
+      else:
+        label_batch = np.zeros(BATCH_SIZE)
+      yield sample_images[i_1:i_2], label_batch, sample_latents[i_1:i_2],\
+          sample_encodings[i_1:i_2]
+  image_gen = sample_images_gen()
+else:
+  sys.exit('Unknown dataset {}.'.format(params.out_dataset))
+
+NUM_IMGS -= NUM_IMGS % BATCH_SIZE
+
+# --------------------------
+# Segmentation.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+# Export attention map for reconstructed images.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch, lat_batch, enc_batch in image_gen:
+
+  # Set target label.
+  if COND_GAN:
+    sess.run(label.assign(one_hot(label_batch)))
+
+  # Initialize encodings.
+  sess.run(latent.assign(lat_batch))
+  sess.run(encoding.assign(enc_batch))
+
+  # Get attention map.
+  _att_map, _gen_img = sess.run([att_map, gen_img])
+
+  # Upsample the attention map's key grid (from 32x32 to 64x64), giving a
+  # 4096x4096 map between all pairs of 64x64 positions.
+  _att_map = np.broadcast_to(_att_map.reshape(64,64,32,1,32,1),
+      (64,64,32,2,32,2)).reshape(4096,4096)
+
+  # Define dissimilarity matrix: symmetrize the attention and zero the
+  # diagonal.
+  dissimilarity = 1.0 - (_att_map + _att_map.T) / 2.0
+  dissimilarity *= (np.ones((4096,4096)) - np.identity(4096))
+
+  # Segment the image with different numbers of clusters.
+  seg_img_8 = segment_img(dissimilarity, 8)
+  seg_img_20 = segment_img(dissimilarity, 20)
+  seg_img_40 = segment_img(dissimilarity, 40)
+
+  # Save segmentation.
+  out_batch_1 = vs.interleave(image_batch, seg_img_20)
+  out_batch_2 = vs.interleave(seg_img_8, seg_img_40)
+  out_batch = vs.interleave(out_batch_1, out_batch_2)
+  out_batch = vs.seq_transform(out_batch)
+  vs.save_image('{}/segmented_img_{}.png'.format(SAMPLES_DIR, it), out_batch)
+
+  it += 1
+
+sess.close()
diff --git a/inversion/visualize.py b/inversion/visualize.py
new file mode 100644
index 0000000..07aea2d
--- /dev/null
+++ b/inversion/visualize.py
@@ -0,0 +1,88 @@
+# ------------------------------------------------------------------------------
+# Util functions to visualize images.
+# ------------------------------------------------------------------------------
+
+import numpy as np
+import cv2 as cv
+from PIL import Image
+
+# Find the most square-like grid (a, b) with a*b == x.
+def split(x):
+  assert type(x) == int
+  t = int(np.floor(np.sqrt(x)))
+  for a in range(t, 0, -1):
+    if x % a == 0:
+      return a, x // a
+
+def grid_transform(x):
+  n, c, h, w = x.shape
+  a, b = split(n)
+  x = np.transpose(x, [0, 2, 3, 1])
+  x = np.reshape(x, [int(a), int(b), int(h), int(w), int(c)])
+  x = np.transpose(x, [0, 2, 1, 3, 4])
+  x = np.reshape(x, [int(a * h), int(b * w), int(c)])
+  if x.shape[2] == 1:
+    x = np.squeeze(x, axis=2)
+  return x
+
+def seq_transform(x):
+  n, c, h, w = x.shape
+  x = np.transpose(x, [2, 0, 3, 1])
+  x = np.reshape(x, [h, n * w, c])
+  return x
+
+# Converts image pixels from range [-1, 1] to [0, 255].
+def data2img(data):
+  rescaled = np.divide(data + 1.0, 2.0) * 255.
+  rescaled = np.clip(rescaled, 0, 255)
+  return np.rint(rescaled).astype('uint8')
+
+def interleave(a, b):
+  res = np.empty([a.shape[0] + b.shape[0]] + list(a.shape[1:]), dtype=a.dtype)
+  res[0::2] = a
+  res[1::2] = b
+  return res
+
+def save_image(filepath, img):
+  pilimg = Image.fromarray(img)
+  pilimg.save(filepath)
+
+def imread(filename):
+  img = cv.imread(filename, cv.IMREAD_UNCHANGED)
+  if img is not None:
+    if len(img.shape) > 2:
+      img = img[...,::-1]  # BGR(A) -> RGB (ARGB for 4-channel images)
+  return img
+
+def imconvert_float32(im):
+  im = np.float32(im)
+  im = (im / 256) * 2.0 - 1
+  return im
+
+# Load an image, scale the short side to opt_dims, center-crop to
+# opt_dims x opt_dims, and normalize to [-1, 1] with a leading batch axis.
+def load_image(opt_fp_in, opt_dims=128):
+  target_im = imread(opt_fp_in)
+  w = target_im.shape[1]
+  h = target_im.shape[0]
+  if w <= h:
+    scale = opt_dims / w
+  else:
+    scale = opt_dims / h
+  target_im = cv.resize(target_im, (0,0), fx=scale, fy=scale)
+  w = target_im.shape[1]
+  h = target_im.shape[0]
+
+  x0 = 0
+  x1 = opt_dims
+  y0 = 0
+  y1 = opt_dims
+  if w > opt_dims:
+    x0 += int((w - opt_dims) / 2)
+    x1 += x0
+  if h > opt_dims:
+    y0 += int((h - opt_dims) / 2)
+    y1 += y0
+  phi_target = imconvert_float32(target_im)
+  phi_target = phi_target[y0:y1,x0:x1]
+  if phi_target.shape[2] == 4:
+    # After the channel reversal in imread, an RGBA image arrives as ARGB,
+    # so channels 1:4 are RGB.
+    phi_target = phi_target[:,:,1:4]
+  phi_target = np.expand_dims(phi_target, 0)
+  return phi_target
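visualize.py is the shared plumbing for all the scripts above; a minimal standalone sketch of its helpers (file names illustrative):

```python
import numpy as np
import visualize as vs

# Load one image as (1, 128, 128, 3) in [-1, 1], tile a fake batch of four,
# convert to the channels_first uint8 layout the scripts use, and save a
# 2x2 grid.
img = vs.load_image('example.jpg', 128)                        # hypothetical path
batch = np.transpose(np.repeat(img, 4, axis=0), [0, 3, 1, 2])  # (4, 3, 128, 128)
grid = vs.grid_transform(vs.data2img(batch))                   # (256, 256, 3) uint8
vs.save_image('grid.png', grid)
```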