| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 18:09:58 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 18:09:58 +0100 |
| commit | 4d0db1188bc94970550c02ff55f16110c5e86700 (patch) | |
| tree | abca403c0972d1762834c7fb6b33423f26b52e24 | |
| parent | 9560db910b876ba5a249f3262bd4e05aa3fa2c2e (diff) | |
new inversion placeholder
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 2 |
|---|---|---|
| -rw-r--r-- | inversion/image_inversion_placeholder.py | 495 |

2 files changed, 496 insertions(+), 1 deletion(-)
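A note on the one-line fix below: in TF 1.x, `tf.Session.run(fetches, feed_dict=None, ...)` takes all fetches as a single first argument, so `sess.run(input_z, input_y)` silently passes `input_y` as the feed dictionary rather than fetching it. Fetching several tensors in one step means grouping them into one structure:

```python
# Correct: a single fetch argument containing both tensors.
z_guess, y_guess = sess.run([input_z, input_y])
```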
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index b681e1f..ab040c0 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -151,7 +151,7 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
     except KeyboardInterrupt:
         pass
-    z_guess, y_guess = sess.run(input_z, input_y)
+    z_guess, y_guess = sess.run([input_z, input_y])
     out_images[index] = phi_target_for_inversion
     out_labels[index] = y_guess
     out_latent[index] = z_guess
diff --git a/inversion/image_inversion_placeholder.py b/inversion/image_inversion_placeholder.py
new file mode 100644
index 0000000..64929cc
--- /dev/null
+++ b/inversion/image_inversion_placeholder.py
@@ -0,0 +1,495 @@
+# ------------------------------------------------------------------------------
+# Inverts a generator by gradient descent w.r.t. the generator's inputs,
+# for the latent vector or any intermediate layer.
+# ------------------------------------------------------------------------------
+
+import glob
+import h5py
+import itertools
+import numpy as np
+import os
+import params
+import PIL
+import scipy
+import sys
+import tensorflow as tf
+import tensorflow_probability as tfp
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+# --------------------------
+# Hyper-parameters.
+# --------------------------
+# Expected parameters:
+#   generator_path: path to the generator module.
+#   generator_fixed_inputs: dictionary of fixed generator input parameters.
+#   dataset: name of the input dataset (hdf5 file).
+#   out_dataset: name for the output inverted dataset (hdf5 file).
+# General parameters:
+#   batch_size: number of images inverted at the same time.
+#   inv_it: number of iterations to invert an image.
+#   inv_layer: 'latent' or the tensor name of the custom layer to invert.
+#   lr: learning rate.
+#   decay_lr: exponential decay on the learning rate.
+#   decay_n: number of exponential decays on the learning rate.
+#   max_batches: maximum number of batches to invert (0 processes all).
+#   custom_grad_relu: replace ReLUs with a custom gradient.
+# Logging:
+#   sample_size: number of images included in sampled images.
+#   save_progress: whether to save intermediate images during optimization.
+#   log_z_norm: log the norm of different sections of z.
+#   log_activation_layer: log the percentage of active neurons in this layer.
+# Losses:
+#   mse: use the mean squared error on pixels for image comparison.
+#   features: use features extracted by a feature extractor for image comparison.
+#   feature_extractor_path: path to the feature extractor module.
+#   feature_extractor_output: output name from the feature extractor.
+#   likeli_loss: regularization loss on the log-likelihood of encodings.
+#   norm_loss: regularization loss on the norm of encodings.
+#   dist_loss: whether to include a loss on the distance between g1(z) and enc.
+#   lambda_mse: coefficient for the mse loss.
+#   lambda_feat: coefficient for the features loss.
+#   lambda_reg: coefficient for the regularization loss on the latent.
+#   lambda_dist: coefficient for the l1 regularization on delta.
+# Latent:
+#   clipping: whether to clip encoding values after every update.
+#   stochastic_clipping: whether to use stochastic clipping.
+#   clip: clipping bound.
+#   pretrained_latent: load a pre-trained fixed latent.
+#   fixed_z: do not train the latent vector.
+# Initialization:
+#   init_gen_dist: initialize encodings from the generated distribution.
+#   init_lo: init min value.
+#   init_hi: init max value.
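As a reading aid, here is a hypothetical configuration covering the parameters documented above. The concrete file format is whatever the repo's `params.Params` loader expects; every key and value below is illustrative only:

```python
# Illustrative configuration (keys follow the docstring above; values are
# made up, and the TF-Hub URL is just an example of a BigGAN module path).
config = {
    "tag": "biggan_256_inv",
    "generator_path": "https://tfhub.dev/deepmind/biggan-256/2",
    "generator_fixed_inputs": {"truncation": 1.0},
    "dataset": "data/targets.hdf5",
    "out_dataset": "inverted.hdf5",
    "batch_size": 16,
    "sample_size": 8,
    "inv_it": 5000,
    "inv_layer": "latent",
    "lr": 0.01,
    "decay_lr": True,
    "decay_n": 3,
    "max_batches": 0,
    "mse": True,
    "features": True,
    "lambda_mse": 1.0,
    "lambda_feat": 1.0,
    "lambda_reg": 0.1,
    "lambda_dist": 0.0,
    "stochastic_clipping": True,
    "clip": 2.0,
    # ...plus the remaining flags documented above (custom_grad_relu,
    # save_progress, fixed_z, init_lo/init_hi, etc.).
}
```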
+if len(sys.argv) < 2:
+    sys.exit('Must provide a configuration file.')
+params = params.Params(sys.argv[1])
+
+# --------------------------
+# Global directories.
+# --------------------------
+LATENT_TAG = 'latent' if params.inv_layer == 'latent' else 'dense'
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+LOGS_DIR = os.path.join('inverses', params.tag, LATENT_TAG, 'logs')
+SAMPLES_DIR = os.path.join('inverses', params.tag, LATENT_TAG, 'samples')
+INVERSES_DIR = os.path.join('inverses', params.tag)
+if not os.path.exists(LOGS_DIR):
+    os.makedirs(LOGS_DIR)
+if not os.path.exists(SAMPLES_DIR):
+    os.makedirs(SAMPLES_DIR)
+if not os.path.exists(INVERSES_DIR):
+    os.makedirs(INVERSES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One-hot encoding for classes.
+def one_hot(values):
+    return np.eye(N_CLASS)[values]
+
+# --------------------------
+# Logging.
+# --------------------------
+summary_writer = tf.summary.FileWriter(LOGS_DIR)
+
+def log_stats(name, val, it):
+    summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
+    summary_writer.add_summary(summary, it)
+
+# --------------------------
+# Load graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+    gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+    Z_DIM = input_info['z'].get_shape().as_list()[1]
+    latent = tf.get_variable(name='latent', dtype=tf.float32,
+                             shape=[BATCH_SIZE, Z_DIM])
+    N_CLASS = input_info['y'].get_shape().as_list()[1]
+    label = tf.get_variable(name='label', dtype=tf.float32,
+                            shape=[BATCH_SIZE, N_CLASS])
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_in['y'] = label
+    gen_img = generator(gen_in, signature=gen_signature)
+else:
+    Z_DIM = input_info['default'].get_shape().as_list()[1]
+    latent = tf.get_variable(name='latent', dtype=tf.float32,
+                             shape=[BATCH_SIZE, Z_DIM])
+    if params.generator_fixed_inputs:
+        gen_in = dict(params.generator_fixed_inputs)
+        gen_in['z'] = latent
+        gen_img = generator(gen_in, signature=gen_signature)
+    else:
+        gen_img = generator(latent, signature=gen_signature)
+
+# Convert the generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override an intermediate layer: reroute the layer's output tensor to a
+# trainable variable so gradient descent can optimize it directly.
+if params.inv_layer == 'latent':
+    encoding = latent
+    ENC_SHAPE = [Z_DIM]
+else:
+    layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+    gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+    ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+    encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+                               shape=[BATCH_SIZE] + ENC_SHAPE)
+    tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Step counter.
+inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
+
+# Define the target image.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+# Normally this is the real [0, 255] image.
+target = tf.get_variable(name='target', dtype=tf.float32,
+                         shape=[BATCH_SIZE] + IMG_SHAPE)
+# target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1.  # Norm to [-1, 1].
+target_img = target
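The graph-surgery call above is the trick that makes intermediate-layer inversion possible. A standalone sketch of the same idea, under the assumption that `tf.contrib.graph_editor.swap_ts` behaves here as it does in the code above (TF 1.x only; the toy tensors are hypothetical):

```python
import tensorflow as tf  # TF 1.x, where tf.contrib is available

x = tf.constant(3.0)
y = tf.square(x)                                   # consumer of x
v = tf.get_variable('override', initializer=10.0)

# Reroute: downstream ops that consumed x now read v instead.
tf.contrib.graph_editor.swap_ts(tf.convert_to_tensor(x),
                                tf.convert_to_tensor(v))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y))  # 100.0 -- computed from v, not x
```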
+# Custom gradient for ReLUs.
+if params.custom_grad_relu:
+    grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5,
+                                             0.1, staircase=False)
+
+    @tf.custom_gradient
+    def relu_custom_grad(x):
+        # Forward pass is a plain ReLU; backward pass leaks a decaying
+        # fraction of the negative upstream gradient through dead units.
+        def grad(dy):
+            return tf.where(x >= 0, dy,
+                            grad_lambda * tf.where(dy < 0, dy,
+                                                   tf.zeros_like(dy)))
+        return tf.nn.relu(x), grad
+
+    gen_scope = 'module_apply_' + gen_signature + '/'
+    for op in tf.get_default_graph().get_operations():
+        if 'Relu' in op.name and gen_scope in op.name:
+            assert len(op.inputs) == 1
+            assert len(op.outputs) == 1
+            new_out = relu_custom_grad(op.inputs[0])
+            tf.contrib.graph_editor.swap_ts(op.outputs[0], new_out)
+
+# Operations to clip the values of the encodings.
+if params.clipping or params.stochastic_clipping:
+    assert params.clip >= 0
+    if params.stochastic_clipping:
+        # Out-of-range components are resampled uniformly in [-clip, clip]
+        # instead of being saturated at the bound.
+        new_enc = tf.where(tf.abs(latent) >= params.clip,
+                           tf.random.uniform([BATCH_SIZE, Z_DIM],
+                                             minval=-params.clip,
+                                             maxval=params.clip),
+                           latent)
+    else:
+        new_enc = tf.clip_by_value(latent, -params.clip, params.clip)
+    clip_latent = tf.assign(latent, new_enc)
+
+# Monitor ReLU activation.
+if params.log_activation_layer:
+    gen_scope = 'module_apply_' + gen_signature + '/'
+    activation_rate = 1.0 - tf.nn.zero_fraction(
+        tf.get_default_graph().get_tensor_by_name(
+            gen_scope + params.log_activation_layer))
+
+# --------------------------
+# Reconstruction losses.
+# --------------------------
+# MSE loss for image comparison.
+if params.mse:
+    pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+    mse_loss = tf.reduce_mean(pix_square_diff)
+    img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1, 2, 3])
+else:
+    mse_loss = tf.constant(0.0)
+    img_mse_err = tf.constant(0.0)
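The stochastic clipping update above is, as far as I can tell, the scheme often attributed to Lipton and Tripathi's "Precise Recovery of Latent Vectors from Generative Adversarial Networks": components that leave [-clip, clip] are resampled uniformly rather than pinned to the bound, which keeps them from sticking there. A minimal NumPy sketch of the same rule:

```python
import numpy as np

def stochastic_clip(z, clip):
    """Resample out-of-range components uniformly in [-clip, clip]."""
    out_of_range = np.abs(z) >= clip
    resampled = np.random.uniform(-clip, clip, size=z.shape)
    return np.where(out_of_range, resampled, z)
```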
+# Use custom features for image comparison.
+if params.features:
+    feature_extractor = hub.Module(str(params.feature_extractor_path))
+
+    # Convert images from [-1, 1] channels_first to [0, 1] channels_last.
+    gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
+    target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
+
+    # Resize images to the input size expected by the feature extractor.
+    height, width = hub.get_expected_image_size(feature_extractor)
+    gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
+    target_img_1 = tf.image.resize_images(target_img_1, [height, width])
+
+    gen_feat_ex = feature_extractor(dict(images=gen_img_1), as_dict=True,
+                                    signature='image_feature_vector')
+    target_feat_ex = feature_extractor(dict(images=target_img_1), as_dict=True,
+                                       signature='image_feature_vector')
+
+    # Earlier experiments (condensed here for brevity) compared only deep
+    # Inception layers -- Mixed_7a/Mixed_7b/Mixed_7c at roughly equal
+    # weights, with Mixed_5a variants also tried -- before settling on the
+    # mix below of early conv layers plus one deep mixed layer.
+
+    gen_feat = gen_feat_ex["InceptionV3/Conv2d_1a_3x3"]
+    target_feat = target_feat_ex["InceptionV3/Conv2d_1a_3x3"]
+    feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+                                  [BATCH_SIZE, -1])
+    feat_loss = tf.reduce_mean(feat_square_diff) * 0.15
+    img_feat_err = tf.reduce_mean(feat_square_diff, axis=1) * 0.15
+
+    gen_feat = gen_feat_ex["InceptionV3/Conv2d_2a_3x3"]
+    target_feat = target_feat_ex["InceptionV3/Conv2d_2a_3x3"]
+    feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+                                  [BATCH_SIZE, -1])
+    feat_loss += tf.reduce_mean(feat_square_diff) * 0.15
+    img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) * 0.15
+
+    gen_feat = gen_feat_ex["InceptionV3/Conv2d_3b_1x1"]
+    target_feat = target_feat_ex["InceptionV3/Conv2d_3b_1x1"]
+    feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+                                  [BATCH_SIZE, -1])
+    feat_loss += tf.reduce_mean(feat_square_diff) * 0.15
+    img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) * 0.15
+
+    gen_feat = gen_feat_ex["InceptionV3/Conv2d_4a_3x3"]
+    target_feat = target_feat_ex["InceptionV3/Conv2d_4a_3x3"]
+    feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+                                  [BATCH_SIZE, -1])
+    feat_loss += tf.reduce_mean(feat_square_diff) * 0.15
+    img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) * 0.15
+
+    gen_feat = gen_feat_ex["InceptionV3/Mixed_7a"]
+    target_feat = target_feat_ex["InceptionV3/Mixed_7a"]
+    feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+                                  [BATCH_SIZE, -1])
+    feat_loss += tf.reduce_mean(feat_square_diff) * 0.4
+    img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) * 0.4
+
+else:
+    feat_loss = tf.constant(0.0)
+    img_feat_err = tf.constant(0.0)
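The five repeated blocks above all apply the same weighted feature-space MSE; a possible refactoring sketch (a hypothetical helper, same arithmetic as the code above):

```python
import tensorflow as tf

def weighted_feature_mse(gen_feats, target_feats, layer_weights, batch_size):
    """Weighted sum of MSEs over feature maps, plus the per-image error."""
    feat_loss, img_feat_err = 0.0, 0.0
    for layer, w in layer_weights.items():
        diff = tf.reshape(tf.square(gen_feats[layer] - target_feats[layer]),
                          [batch_size, -1])
        feat_loss += w * tf.reduce_mean(diff)
        img_feat_err += w * tf.reduce_mean(diff, axis=1)
    return feat_loss, img_feat_err

# Usage mirroring the weights above:
# feat_loss, img_feat_err = weighted_feature_mse(
#     gen_feat_ex, target_feat_ex,
#     {"InceptionV3/Conv2d_1a_3x3": 0.15, "InceptionV3/Conv2d_2a_3x3": 0.15,
#      "InceptionV3/Conv2d_3b_1x1": 0.15, "InceptionV3/Conv2d_4a_3x3": 0.15,
#      "InceptionV3/Mixed_7a": 0.4}, BATCH_SIZE)
```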
+# --------------------------
+# Regularization losses.
+# --------------------------
+# Loss on the norm of the encoding.
+if params.norm_loss:
+    # The squared L2 norm of `dim` standard-normal coordinates follows a
+    # chi-squared distribution with `dim` degrees of freedom, whose mode
+    # is dim - 2. Only norms above the mode are penalized.
+    dim = 20
+    chi2_dist = tfp.distributions.Chi2(dim)
+    mode = dim - 2
+    mode_log_prob = chi2_dist.log_prob(mode)
+    norm_loss = 0.0
+    for i in range(int(Z_DIM / dim)):
+        squared_l2 = tf.reduce_sum(tf.square(latent[:, i*dim:(i+1)*dim]),
+                                   axis=1)
+        over_mode = tf.nn.relu(squared_l2 - mode)
+        norm_loss -= tf.reduce_mean(chi2_dist.log_prob(mode + over_mode))
+        norm_loss += mode_log_prob
+else:
+    norm_loss = tf.constant(0.0)
+
+# Loss on the likelihood of the encoding.
+if params.likeli_loss:
+    norm_dist = tfp.distributions.Normal(0.0, 1.0)
+    likeli_loss = -tf.reduce_mean(norm_dist.log_prob(latent))
+    mode_log_prob = norm_dist.log_prob(0.0)
+    likeli_loss += mode_log_prob
+else:
+    likeli_loss = tf.constant(0.0)
+
+# Regularization loss.
+reg_loss = norm_loss + likeli_loss
+
+# Loss on the l1 distance between gen_encoding and the inverted encoding.
+if params.dist_loss:
+    dist_loss = tf.reduce_mean(tf.abs(encoding - gen_encoding))
+else:
+    dist_loss = tf.constant(0.0)
+
+# Per-image reconstruction error.
+img_rec_err = params.lambda_mse * img_mse_err \
+    + params.lambda_feat * img_feat_err
+
+# Batch reconstruction error.
+rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss
+
+# Total inversion loss.
+inv_loss = rec_loss + params.lambda_reg * reg_loss \
+    + params.lambda_dist * dist_loss
+
+# --------------------------
+# Optimizer.
+# --------------------------
+if params.decay_lr:
+    lrate = tf.train.exponential_decay(params.lr, inv_step,
+                                       params.inv_it / params.decay_n,
+                                       0.1, staircase=True)
+else:
+    lrate = tf.constant(params.lr)
+trained_params = [encoding] if params.fixed_z else [latent, encoding]
+optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
+                                  global_step=inv_step)
+reinit_optimizer = tf.variables_initializer(optimizer.variables())
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+    return np.random.normal(size=[BATCH_SIZE, Z_DIM])
+
+def small_init(shape=[BATCH_SIZE, Z_DIM]):
+    return np.random.uniform(low=params.init_lo, high=params.init_hi, size=shape)
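A quick numerical sanity check of the chi-squared prior behind `norm_loss` above (NumPy/SciPy; the 20-dimensional chunk size matches `dim` in the code):

```python
import numpy as np
from scipy import stats

dim = 20
z = np.random.normal(size=(100_000, dim))
sq_norms = np.sum(z ** 2, axis=1)

# The squared norm of `dim` iid standard normals is chi-squared with
# `dim` degrees of freedom; its density peaks at the mode dim - 2 = 18.
print(np.bincount(sq_norms.astype(int)).argmax())  # ~18
print(stats.chi2.pdf(dim - 2, df=dim))             # maximum of the pdf
```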
+ print("Number of images: {}".format(NUM_IMGS)) + print("Batch size: {}".format(BATCH_SIZE)) + def sample_images_gen(): + for i in range(int(NUM_IMGS / BATCH_SIZE)): + i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE + if COND_GAN: + yield sample_images[i_1:i_2], sample_labels[i_1:i_2] + else: + yield sample_images[i_1:i_2], np.zeros(BATCH_SIZE) + image_gen = sample_images_gen() + if 'latent' in in_file: + sample_latents = in_file['latent'] + def sample_latent_gen(): + for i in range(int(NUM_IMGS / BATCH_SIZE)): + i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE + yield sample_latents[i_1:i_2] + latent_gen = sample_latent_gen() + if NUM_IMGS % BATCH_SIZE != 0: + REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE) + NUM_IMGS += REMAINDER + sample_images = np.append(sample_images, sample_images[-REMAINDER:,...], axis=0) + sample_labels = np.append(sample_labels, sample_labels[-REMAINDER:,...], axis=0) + sample_fns = np.append(sample_fns, sample_fns[-REMAINDER:], axis=0) + assert(NUM_IMGS % BATCH_SIZE == 0) +else: + sys.exit('Unknown dataset {}.'.format(params.dataset)) + +# -------------------------- +# Training. +# -------------------------- +# Start session. +sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) +sess.run(tf.global_variables_initializer()) +sess.run(tf.tables_initializer()) + +if params.max_batches > 0: + NUM_IMGS_TO_PROCESS = params.max_batches * BATCH_SIZE +else: + NUM_IMGS_TO_PROCESS = NUM_IMGS + +# Output file. +out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w') +out_images = out_file.create_dataset('xtrain', [NUM_IMGS_TO_PROCESS,] + IMG_SHAPE, dtype='float32') +out_enc = out_file.create_dataset('encoding', [NUM_IMGS_TO_PROCESS,] + ENC_SHAPE, dtype='float32') +out_lat = out_file.create_dataset('latent', [NUM_IMGS_TO_PROCESS, Z_DIM], dtype='float32') +out_fns = out_file.create_dataset('fn', [NUM_IMGS_TO_PROCESS], dtype=h5py.string_dtype()) +if COND_GAN: + out_labels = out_file.create_dataset('ytrain', (NUM_IMGS_TO_PROCESS, N_CLASS,), dtype='float32') +out_err = out_file.create_dataset('err', (NUM_IMGS_TO_PROCESS,)) + +out_fns[:] = sample_fns[:NUM_IMGS_TO_PROCESS] + +# Gradient descent w.r.t. generator's inputs. +it = 0 +out_pos = 0 +start_time = time.time() + +for image_batch, label_batch in image_gen: + + # Save target. + sess.run([ + target.assign(image_batch), + label.assign(label_batch), + latent.assign(next(latent_gen)), + inv_step.assign(0), + ]) + sess.run([ + encoding.assign(gen_encoding), + reinit_optimizer, + ]) + + # Main optimization loop. + print("Total iterations: {}".format(params.inv_it)) + for _ in range(params.inv_it): + + _inv_loss, _mse_loss, _feat_loss, _rec_loss, _reg_loss, _dist_loss,\ + _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, + rec_loss, reg_loss, dist_loss, lrate, inv_train_op]) + + if params.clipping or params.stochastic_clipping: + sess.run(clip_latent) + + # Save logs with training information. + if it % 500 == 0: + # Log losses. + etime = time.time() - start_time + print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] ' + 'feat [{:.4f}] rec [{:.4f}] reg [{:.4f}] dist [{:.4f}] ' + 'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss, + _feat_loss, _rec_loss, _reg_loss, _dist_loss, _lrate)) + + sys.stdout.flush() + + # Save target images and reconstructions. 
+# Gradient descent w.r.t. the generator's inputs.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch in image_gen:
+
+    # Load the target batch and reset the optimization state.
+    assigns = [target.assign(image_batch), inv_step.assign(0)]
+    if COND_GAN:
+        assigns.append(label.assign(label_batch))
+    if latent_gen is not None:
+        assigns.append(latent.assign(next(latent_gen)))
+    else:
+        # No pre-trained latents in the dataset: fall back to a small
+        # uniform init.
+        assigns.append(latent.assign(small_init()))
+    sess.run(assigns)
+    reset_ops = [reinit_optimizer]
+    if params.inv_layer != 'latent':
+        # Initialize the trainable encoding from the generator's own
+        # activation at that layer.
+        reset_ops.append(encoding.assign(gen_encoding))
+    sess.run(reset_ops)
+
+    # Main optimization loop.
+    print("Total iterations: {}".format(params.inv_it))
+    for _ in range(params.inv_it):
+
+        _inv_loss, _mse_loss, _feat_loss, _rec_loss, _reg_loss, _dist_loss, \
+            _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, rec_loss,
+                                  reg_loss, dist_loss, lrate, inv_train_op])
+
+        if params.clipping or params.stochastic_clipping:
+            sess.run(clip_latent)
+
+        # Save logs with training information.
+        if it % 500 == 0:
+            etime = time.time() - start_time
+            print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
+                  'feat [{:.4f}] rec [{:.4f}] reg [{:.4f}] dist [{:.4f}] '
+                  'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
+                                       _feat_loss, _rec_loss, _reg_loss,
+                                       _dist_loss, _lrate))
+            sys.stdout.flush()
+
+            # Save target images and reconstructions.
+            if params.save_progress:
+                assert SAMPLE_SIZE <= BATCH_SIZE
+                gen_time = time.time()
+                gen_images = sess.run(gen_img)
+                print("Generation time: {:.1f}s".format(time.time() - gen_time))
+                inv_batch = vs.interleave(
+                    vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
+                    vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+                inv_batch = vs.grid_transform(inv_batch)
+                vs.save_image('{}/progress_{}_{:04d}.png'.format(
+                    SAMPLES_DIR, params.tag, int(it / 500)), inv_batch)
+
+        it += 1
+
+    # Save the finished batch.
+    latent_batch, enc_batch, rec_err_batch = sess.run([latent, encoding,
+                                                       img_rec_err])
+    out_lat[out_pos:out_pos+BATCH_SIZE] = latent_batch
+    out_enc[out_pos:out_pos+BATCH_SIZE] = enc_batch
+    out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
+    if COND_GAN:
+        out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
+    out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
+    out_pos += BATCH_SIZE
+    if params.max_batches > 0 and (out_pos / BATCH_SIZE) >= params.max_batches:
+        break
+
+print('Mean reconstruction error: {}'.format(np.mean(out_err)))
+print('Stdev reconstruction error: {}'.format(np.std(out_err)))
+print('End of inversion.')
+out_file.close()
+sess.close()
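Once a run finishes, the output hdf5 can be inspected directly. A sketch; the path depends on the configured `tag` and `out_dataset`, and the filename below is illustrative:

```python
import h5py

# Keys follow the create_dataset calls above.
with h5py.File('inverses/my_tag/inverted.hdf5', 'r') as f:
    print(f['xtrain'].shape, f['encoding'].shape, f['latent'].shape)
    print('mean reconstruction error:', f['err'][:].mean())
```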
