diff options
Diffstat (limited to 'inversion/image_sample.py')
| -rw-r--r-- | inversion/image_sample.py | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/inversion/image_sample.py b/inversion/image_sample.py new file mode 100644 index 0000000..83622a1 --- /dev/null +++ b/inversion/image_sample.py @@ -0,0 +1,295 @@ +# ------------------------------------------------------------------------------ +# Generate random samples of the generator and save the images to a hdf5 file. +# ------------------------------------------------------------------------------ + +import h5py +import numpy as np +import os +import sys +import tensorflow as tf +import tensorflow_hub as hub +import time +import visualize as vs +import argparse +from glob import glob + +# -------------------------- +# Hyper-parameters. +# -------------------------- +# Expected parameters: +# generator_path: path to generator module. +# generator_fixed_inputs: dictionary of fixed generator's input parameters. +# dataset_out: name for the output created dataset (hdf5 file). +# General parameters: +# batch_size: number of images generated at the same time. +# random_label: choose random labels. +# num_imgs: number of instances to generate. +# custom_label: custom label to be fixed. +# Logging: +# sample_size: number of images included in sampled images. +if len(sys.argv) < 2: + sys.exit('Must provide a configuration file.') + +parser = argparse.ArgumentParser(description='Initialize the image search.') +parser.add_argument('--input_dir', required=True, help='Input directory of images') +parser.add_argument('--tag', default=str(int(time.time())), help='Tag this build') +params = parser.parse_args() + +# -------------------------- +# Hyper-parameters. +# -------------------------- +# General parameters. +BATCH_SIZE = 20 +SAMPLE_SIZE = 20 +assert SAMPLE_SIZE <= BATCH_SIZE + +INVERSION_ITERATIONS = 1000 + +# -------------------------- +# Global directories. +# -------------------------- +DATASET_OUT = "{}_dataset.hdf5".format(params.tag) +SAMPLES_DIR = './outputs/{}/samples'.format(params.tag) +INVERSES_DIR = './outputs/{}/inverses'.format(params.tag) +os.makedirs(SAMPLES_DIR, exist_ok=True) +os.makedirs(INVERSES_DIR, exist_ok=True) + +# -------------------------- +# Load Graph. +# -------------------------- +generator = hub.Module('https://tfhub.dev/deepmind/biggan-128/2') + +gen_signature = 'generator' +if 'generator' not in generator.get_signature_names(): + gen_signature = 'default' + +input_info = generator.get_input_info_dict(gen_signature) + +Z_DIM = input_info['z'].get_shape().as_list()[1] +latent = tf.get_variable(name='latent', dtype=tf.float32, + shape=[BATCH_SIZE, Z_DIM]) +N_CLASS = input_info['y'].get_shape().as_list()[1] +label = tf.get_variable(name='label', dtype=tf.float32, + shape=[BATCH_SIZE, N_CLASS]) +gen_in = {} +gen_in['truncation'] = 1.0 +gen_in['z'] = latent +gen_in['y'] = label +gen_img = generator(gen_in, signature=gen_signature) + +# Convert generated image to channels_first. +gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) + +# Override intermediate layer. +encoding = latent +ENC_SHAPE = [Z_DIM] + +# Step counter. +inv_step = tf.get_variable('inv_step', initializer=0, trainable=False) + +# Define target image. +IMG_SHAPE = gen_img.get_shape().as_list()[1:] +target = tf.get_variable(name='target', dtype=tf.int32, + shape=[BATCH_SIZE,] + IMG_SHAPE) +target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1. # Norm to [-1, 1]. + +# Monitor relu's activation. +gen_scope = 'module_apply_' + gen_signature + '/' +activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()\ + .get_tensor_by_name(gen_scope + params.log_activation_layer)) + +# -------------------------- +# Reconstruction losses. +# -------------------------- +# Mse loss for image comparison. +pix_square_diff = tf.square((target_img - gen_img) / 2.0) +mse_loss = tf.reduce_mean(pix_square_diff) +img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3]) + +# Use custom features for image comparison. +feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1") + +# Convert images from range [-1, 1] channels_first to [0, 1] channels_last. +gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1]) +target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1]) + +# Convert images to appropriate size for feature extraction. +height, width = hub.get_expected_image_size(feature_extractor) +gen_img_1 = tf.image.resize_images(gen_img_1, [height, width]) +target_img_1 = tf.image.resize_images(target_img_1, [height, width]) + +gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True, + signature='image_feature_vector')[params.feature_extractor_output] +target_feat = feature_extractor(dict(images=target_img_1), as_dict=True, + signature='image_feature_vector')[params.feature_extractor_output] +feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), + [BATCH_SIZE, -1]) +feat_loss = tf.reduce_mean(feat_square_diff) +img_feat_err = tf.reduce_mean(feat_square_diff, axis=1) + +# -------------------------- +# Regularization losses. +# -------------------------- +norm_dist = tfp.distributions.Normal(0.0, 1.0) +likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent)) +mode_log_prob = norm_dist.log_prob(0.0) +likeli_loss += mode_log_prob + +# Per image reconstruction error. +img_rec_err = 1.0 * img_mse_err + 1.0 * img_feat_err + +# Batch reconstruction error. +rec_loss = 1.0 * mse_loss + 1.0 * feat_loss + +# Total inversion loss. +inv_loss = rec_loss + 0.1 * likeli_loss + +# -------------------------- +# Optimizer. +# -------------------------- +lrate = tf.train.exponential_decay(0.1, inv_step, + INVERSION_ITERATIONS / 2, 0.1, staircase=True) +trained_params = [latent, label] +optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999) +inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params, + global_step=inv_step) +reinit_optimizer = tf.variables_initializer(optimizer.variables()) + +# -------------------------- +# Noise source. +# -------------------------- +def noise_sampler(): + return np.random.normal(size=[BATCH_SIZE, Z_DIM]) + +def label_init(shape=[BATCH_SIZE, N_CLASS]): + return np.random.uniform(low=0.00, high=0.01, size=shape) + +# -------------------------- +# Dataset. +# -------------------------- + +paths = glob(os.path.join(params.input_dir, '*.jpg')) + \ + glob(os.path.join(params.input_dir, '*.jpeg')) + \ + glob(os.path.join(params.input_dir, '*.png')) +sample_images = [ vs.load_image(fn, 128) for fn in sorted(paths) ] +ACTUAL_NUM_IMGS = sample_images.shape[0] # number of images to be inverted. +print("Number of images: {}".format(ACTUAL_NUM_IMGS)) +NUM_IMGS = ACTUAL_NUM_IMGS + +# pad the image array to match the batch size +while NUM_IMGS % BATCH_SIZE != 0: + sample_images += sample_images[-1] + NUM_IMGS += 1 +sample_images = np.array(sample_images) + +def sample_images_gen(): + for i in range(int(NUM_IMGS / BATCH_SIZE)): + i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE + yield sample_images[i_1:i_2] +image_gen = sample_images_gen() + +assert(NUM_IMGS % BATCH_SIZE == 0) + +# -------------------------- +# Generation. +# -------------------------- +# Start session. +sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) +sess.run(tf.global_variables_initializer()) +sess.run(tf.tables_initializer()) + +# Output file. +out_file = h5py.File(os.path.join(INVERSES_DIR, DATASET_OUT), 'w') +out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='uint8') +out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32') + +# Gradient descent w.r.t. generator's inputs. +it = 0 +out_pos = 0 +start_time = time.time() + +for image_batch in image_gen: + # Set target. + sess.run(target.assign(image_batch)) + + # Start with a random label + label_batch = label_sampler(BATCH_SIZE) + sess.run(label.assign(label_init())) + + # Start with a random vector + sess.run(latent.assign(noise_sampler())) + + # Init optimizer. + sess.run(inv_step.assign(0)) + sess.run(reinit_optimizer) + + # Main optimization loop. + for _ in range(params.inv_it): + _inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss,\ + _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, + rec_loss, likeli_loss, lrate, inv_train_op]) + + # Every 100 iterations save logs with training information. + if it % 100 == 0: + # Log losses. + etime = time.time() - start_time + + _act_rate = sess.run(activation_rate) + print('activation_rate={:.4f}'.format(_act_rate)) + log_stats('activation rate', _act_rate, it) + + sys.stdout.flush() + + # Log tensorboard's statistics. + log_stats('total loss', _inv_loss, it) + log_stats('mse loss', _mse_loss, it) + log_stats('feat loss', _feat_loss, it) + log_stats('rec loss', _rec_loss, it) + log_stats('reg loss', _reg_loss, it) + log_stats('dist loss', _dist_loss, it) + log_stats('out pos', out_pos, it) + log_stats('lrate', _lrate, it) + summary_writer.flush() + + gen_images = sess.run(gen_img) + inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:], + vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:])) + inv_batch = vs.grid_transform(inv_batch) + vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch) + + it += 1 + + # gen_images = sess.run(gen_img) + # gen_images = vs.data2img(gen_images) + label_batch = sess.run(label) + print(label_batch.shape) + + out_images[i:i+BATCH_SIZE] = image_batch + out_labels[i:i+BATCH_SIZE] = label_batch + + out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE]) + vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch) + print('Saved samples for imgs: {}-{}.'.format(i,i+BATCH_SIZE)) + + + + + + + + + + + + + + + + + + + + + + + |
