summaryrefslogtreecommitdiff
path: root/inversion/image_sample.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/image_sample.py')
-rw-r--r-- inversion/image_sample.py 295
1 files changed, 295 insertions, 0 deletions
diff --git a/inversion/image_sample.py b/inversion/image_sample.py
new file mode 100644
index 0000000..83622a1
--- /dev/null
+++ b/inversion/image_sample.py
@@ -0,0 +1,295 @@
+# ------------------------------------------------------------------------------
+# Invert the images of an input directory with the BigGAN generator (optimize a
+# latent vector and a label per image) and save images and labels to a hdf5 file.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+import argparse
+from glob import glob
+
+# --------------------------
+# Command-line parameters.
+# --------------------------
+# Required:
+#   --input_dir: directory holding the images to invert.
+#   --log_activation_layer: name of the generator tensor whose fraction of
+#       non-zero entries is monitored, given relative to the
+#       'module_apply_<signature>/' scope. It has no default because the
+#       tensor names depend on the generator module.
+# Optional:
+#   --tag: name of this run, used to build the output paths.
+#       Defaults to the Unix time at start-up.
+#   --feature_extractor_output: key of the feature extractor's output
+#       dictionary that is compared between target and generated images.
+#
+# argparse itself reports missing required arguments together with a usage
+# message, so no manual sys.argv length check is needed.
+parser = argparse.ArgumentParser(description='Initialize the image search.')
+parser.add_argument('--input_dir', required=True, help='Input directory of images')
+parser.add_argument('--tag', default=str(int(time.time())), help='Tag this build')
+parser.add_argument('--log_activation_layer', required=True,
+                    help='Generator tensor (relative to the module_apply scope) '
+                         'whose non-zero fraction is logged')
+parser.add_argument('--feature_extractor_output', default='default',
+                    help='Output key of the feature extractor used for the feature loss')
+params = parser.parse_args()
+
+# --------------------------
+# Fixed settings.
+# --------------------------
+# BATCH_SIZE images are inverted together; SAMPLE_SIZE of them appear in the
+# preview images, so it can never exceed the batch size.
+BATCH_SIZE = 20
+SAMPLE_SIZE = 20
+assert SAMPLE_SIZE <= BATCH_SIZE
+
+# Optimization steps spent on each batch.
+INVERSION_ITERATIONS = 1000
+
+# --------------------------
+# Output locations (all derived from the run tag).
+# --------------------------
+DATASET_OUT = "{}_dataset.hdf5".format(params.tag)
+SAMPLES_DIR = './outputs/{}/samples'.format(params.tag)
+INVERSES_DIR = './outputs/{}/inverses'.format(params.tag)
+for _out_dir in (SAMPLES_DIR, INVERSES_DIR):
+    os.makedirs(_out_dir, exist_ok=True)
+
+# --------------------------
+# Generator graph.
+# --------------------------
+generator = hub.Module('https://tfhub.dev/deepmind/biggan-128/2')
+
+# Use the 'generator' signature when the module exposes it, else 'default'.
+if 'generator' in generator.get_signature_names():
+    gen_signature = 'generator'
+else:
+    gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+
+# The second dimension of the module's 'z' and 'y' inputs fixes the size of
+# the two variables that are fed to the generator.
+Z_DIM = input_info['z'].get_shape().as_list()[1]
+latent = tf.get_variable(name='latent', dtype=tf.float32,
+                         shape=[BATCH_SIZE, Z_DIM])
+N_CLASS = input_info['y'].get_shape().as_list()[1]
+label = tf.get_variable(name='label', dtype=tf.float32,
+                        shape=[BATCH_SIZE, N_CLASS])
+
+gen_in = {'truncation': 1.0, 'z': latent, 'y': label}
+gen_img = generator(gen_in, signature=gen_signature)
+
+# Move the channel axis to position 1 (channels_first).
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# The optimized encoding is the latent variable itself; ENC_SHAPE is its
+# per-image shape.
+encoding = latent
+ENC_SHAPE = [Z_DIM]
+
+# Non-trainable integer counter, used as the optimizer's global step.
+inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
+
+# Target images are held in an int32 variable with the same per-image shape as
+# the (channels_first) generator output.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+target = tf.get_variable(name='target', dtype=tf.int32,
+                         shape=[BATCH_SIZE] + IMG_SHAPE)
+
+# Rescale to [-1, 1] (assumes pixel values in 0..255).
+_target_unit = tf.cast(target, tf.float32) / 255.
+target_img = _target_unit * 2.0 - 1.
+
+# Fraction of non-zero entries of one generator tensor, used for monitoring.
+# The tensor is looked up by name inside the scope created when the module
+# signature is applied.
+# NOTE(review): reads params.log_activation_layer -- confirm that the argument
+# parser defines this option.
+gen_scope = 'module_apply_' + gen_signature + '/'
+_monitored_tensor = tf.get_default_graph().get_tensor_by_name(
+    gen_scope + params.log_activation_layer)
+activation_rate = 1.0 - tf.nn.zero_fraction(_monitored_tensor)
+
+# --------------------------
+# Reconstruction losses.
+# --------------------------
+# Pixel loss: squared difference of the two [-1, 1] images, halved first so
+# the difference lies in [-1, 1]. Averaged over the batch (mse_loss) and per
+# image (img_mse_err).
+pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1, 2, 3])
+mse_loss = tf.reduce_mean(pix_square_diff)
+
+# Feature loss: compare the images in the feature space of a pretrained
+# network.
+feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
+height, width = hub.get_expected_image_size(feature_extractor)
+
+
+def _prepare_for_features(img):
+    """Map a [-1, 1] channels_first batch to [0, 1] channels_last at the extractor's input size."""
+    img_unit = tf.transpose(img / 2.0 + 0.5, [0, 2, 3, 1])
+    return tf.image.resize_images(img_unit, [height, width])
+
+
+def _features(img):
+    """Apply the feature extractor and pick the configured output.
+
+    NOTE(review): reads params.feature_extractor_output -- confirm that the
+    argument parser defines this option.
+    """
+    outputs = feature_extractor(dict(images=img), as_dict=True,
+                                signature='image_feature_vector')
+    return outputs[params.feature_extractor_output]
+
+
+gen_img_1 = _prepare_for_features(gen_img)
+target_img_1 = _prepare_for_features(target_img)
+
+gen_feat = _features(gen_img_1)
+target_feat = _features(target_img_1)
+
+# Flatten the squared feature differences to one row per image, then average
+# per image (img_feat_err) and over the whole batch (feat_loss).
+feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
+                              [BATCH_SIZE, -1])
+img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)
+feat_loss = tf.reduce_mean(feat_square_diff)
+
+# --------------------------
+# Regularization losses.
+# --------------------------
+# Negative mean log-density of the latent under a standard normal, shifted by
+# the log-density at the mode so that the loss is zero at latent == 0.
+# The density is written out in closed form,
+#     log N(x | 0, 1) = -x^2 / 2 - log(2 * pi) / 2,
+# because `tfp` is never imported in this file, so the former
+# `tfp.distributions.Normal` call raised a NameError.
+mode_log_prob = float(-0.5 * np.log(2.0 * np.pi))  # log N(0 | 0, 1)
+latent_log_prob = -0.5 * tf.square(latent) + mode_log_prob
+likeli_loss = -tf.reduce_mean(latent_log_prob) + mode_log_prob
+
+# Per image reconstruction error.
+img_rec_err = 1.0 * img_mse_err + 1.0 * img_feat_err
+
+# Batch reconstruction error.
+rec_loss = 1.0 * mse_loss + 1.0 * feat_loss
+
+# Total inversion loss: reconstruction plus the weighted latent prior term.
+inv_loss = rec_loss + 0.1 * likeli_loss
+
+# --------------------------
+# Optimizer.
+# --------------------------
+# Learning rate starts at 0.1 and is multiplied by 0.1 every
+# INVERSION_ITERATIONS / 2 steps (staircase decay driven by inv_step).
+lrate = tf.train.exponential_decay(
+    learning_rate=0.1,
+    global_step=inv_step,
+    decay_steps=INVERSION_ITERATIONS / 2,
+    decay_rate=0.1,
+    staircase=True)
+
+# Only the generator inputs are optimized.
+trained_params = [latent, label]
+optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+inv_train_op = optimizer.minimize(
+    inv_loss, global_step=inv_step, var_list=trained_params)
+
+# Op that re-initializes the optimizer's own variables.
+reinit_optimizer = tf.variables_initializer(optimizer.variables())
+
+# --------------------------
+# Noise source.
+# --------------------------
+def noise_sampler():
+    """Return a [BATCH_SIZE, Z_DIM] array of standard-normal samples."""
+    latent_shape = [BATCH_SIZE, Z_DIM]
+    return np.random.normal(loc=0.0, scale=1.0, size=latent_shape)
+
+def label_init(shape=(BATCH_SIZE, N_CLASS)):
+    """Return an array of the given shape drawn uniformly from [0.00, 0.01).
+
+    Args:
+        shape: shape of the returned array; defaults to one row per batch
+            element and one column per label entry. The default is a tuple
+            so that it cannot be mutated between calls.
+
+    Returns:
+        A numpy array of small positive values.
+    """
+    return np.random.uniform(low=0.00, high=0.01, size=shape)
+
+# --------------------------
+# Dataset.
+# --------------------------
+# Collect every .jpg / .jpeg / .png file of the input directory.
+paths = glob(os.path.join(params.input_dir, '*.jpg')) + \
+    glob(os.path.join(params.input_dir, '*.jpeg')) + \
+    glob(os.path.join(params.input_dir, '*.png'))
+if not paths:
+    # Without this check the script would silently write an empty dataset.
+    sys.exit('No .jpg, .jpeg or .png images found in {}'.format(params.input_dir))
+
+# Load in sorted order so that runs are reproducible.
+# NOTE(review): 128 is passed to vs.load_image -- presumably the image size
+# expected by the biggan-128 generator; confirm against visualize.load_image.
+sample_images = [vs.load_image(fn, 128) for fn in sorted(paths)]
+ACTUAL_NUM_IMGS = len(sample_images)  # number of images to be inverted.
+print("Number of images: {}".format(ACTUAL_NUM_IMGS))
+
+# Pad with copies of the last image so that the total is a whole number of
+# batches. The padding is appended element-wise: `list += image` would
+# instead extend the list with the rows of that image.
+num_padding = -ACTUAL_NUM_IMGS % BATCH_SIZE
+sample_images.extend([sample_images[-1]] * num_padding)
+NUM_IMGS = len(sample_images)
+sample_images = np.array(sample_images)
+
+def sample_images_gen():
+    """Yield consecutive slices of BATCH_SIZE images from `sample_images`."""
+    num_batches = int(NUM_IMGS / BATCH_SIZE)
+    for batch_idx in range(num_batches):
+        start = batch_idx * BATCH_SIZE
+        yield sample_images[start:start + BATCH_SIZE]
+
+
+image_gen = sample_images_gen()
+
+# The image count must be a whole number of batches at this point.
+assert NUM_IMGS % BATCH_SIZE == 0
+
+# --------------------------
+# Session and output file.
+# --------------------------
+# Start the session and initialize all variables and lookup tables.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
+
+# Output hdf5 file: one uint8 image ('xtrain') and one uint32 label
+# ('ytrain') per input image, padding included.
+_out_path = os.path.join(INVERSES_DIR, DATASET_OUT)
+out_file = h5py.File(_out_path, 'w')
+out_images = out_file.create_dataset('xtrain', shape=[NUM_IMGS] + IMG_SHAPE, dtype='uint8')
+out_labels = out_file.create_dataset('ytrain', shape=(NUM_IMGS,), dtype='uint32')
+
+# Gradient descent w.r.t. generator's inputs.
+def log_stats(name, value, step):
+    """Print one named scalar together with the global iteration it belongs to.
+
+    This script configures no summary writer, so statistics go to stdout.
+    """
+    print('it={} {}={}'.format(step, name, value))
+
+
+# Build the per-batch reset ops once. Calling `var.assign(value)` inside the
+# loop would add new ops to the graph, with the batch embedded as a constant,
+# for every batch.
+target_in = tf.placeholder(tf.int32, [BATCH_SIZE] + IMG_SHAPE)
+latent_in = tf.placeholder(tf.float32, [BATCH_SIZE, Z_DIM])
+label_in = tf.placeholder(tf.float32, [BATCH_SIZE, N_CLASS])
+reset_batch = tf.group(target.assign(target_in),
+                       latent.assign(latent_in),
+                       label.assign(label_in),
+                       inv_step.assign(0))
+
+it = 0        # global iteration counter, shared by all batches.
+out_pos = 0   # first free row of the output datasets.
+start_time = time.time()
+
+try:
+    for image_batch in image_gen:
+        # Set the target, restart from a random latent and a near-zero label,
+        # and reset the step counter and the optimizer state.
+        sess.run([reset_batch, reinit_optimizer],
+                 feed_dict={target_in: image_batch,
+                            latent_in: noise_sampler(),
+                            label_in: label_init()})
+
+        # Main optimization loop. The step count is the same constant that
+        # drives the learning-rate schedule.
+        for _ in range(INVERSION_ITERATIONS):
+            (_inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss,
+             _lrate, _) = sess.run([inv_loss, mse_loss, feat_loss,
+                                    rec_loss, likeli_loss, lrate, inv_train_op])
+
+            # Every 100 iterations log statistics and save a preview image.
+            if it % 100 == 0:
+                etime = time.time() - start_time
+
+                _act_rate = sess.run(activation_rate)
+                print('activation_rate={:.4f}'.format(_act_rate))
+                log_stats('activation rate', _act_rate, it)
+
+                log_stats('total loss', _inv_loss, it)
+                log_stats('mse loss', _mse_loss, it)
+                log_stats('feat loss', _feat_loss, it)
+                log_stats('rec loss', _rec_loss, it)
+                # The likelihood term is the only regularizer of inv_loss.
+                log_stats('reg loss', _likeli_loss, it)
+                log_stats('out pos', out_pos, it)
+                log_stats('lrate', _lrate, it)
+                log_stats('elapsed time', etime, it)
+                sys.stdout.flush()
+
+                # Targets interleaved with their current reconstructions.
+                gen_images = sess.run(gen_img)
+                inv_batch = vs.interleave(
+                    image_batch[BATCH_SIZE - SAMPLE_SIZE:],
+                    vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+                inv_batch = vs.grid_transform(inv_batch)
+                vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+
+            it += 1
+
+        # Fetch the final state of this batch rather than reusing the
+        # reconstruction from the last logging step.
+        label_batch = sess.run(label)
+        # NOTE(review): data2img is applied as in the progress preview above;
+        # confirm it is the conversion grid_transform / save_image expect.
+        gen_images = vs.data2img(sess.run(gen_img))
+
+        out_end = out_pos + BATCH_SIZE
+        out_images[out_pos:out_end] = image_batch
+        # 'ytrain' holds one integer per image, so the index of the largest
+        # label entry is stored instead of the whole label vector.
+        out_labels[out_pos:out_end] = np.argmax(label_batch, axis=1)
+
+        out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
+        vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, out_pos), out_batch)
+        print('Saved samples for imgs: {}-{}.'.format(out_pos, out_end))
+        out_pos = out_end
+finally:
+    # Flush the hdf5 file and release the session even if a batch fails.
+    out_file.close()
+    sess.close()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+