summaryrefslogtreecommitdiff
path: root/inversion/random_sample.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/random_sample.py')
-rw-r--r--inversion/random_sample.py144
1 files changed, 144 insertions, 0 deletions
diff --git a/inversion/random_sample.py b/inversion/random_sample.py
new file mode 100644
index 0000000..61cac9c
--- /dev/null
+++ b/inversion/random_sample.py
@@ -0,0 +1,144 @@
+# ------------------------------------------------------------------------------
+# Generate random samples of the generator and save the images to a hdf5 file.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
# --------------------------
# Hyper-parameters.
# --------------------------
# Expected parameters:
#  generator_path: path to generator module.
#  generator_fixed_inputs: dictionary of fixed generator's input parameters.
#  dataset_out: name for the output created dataset (hdf5 file).
# General parameters:
#  batch_size: number of images generated at the same time.
#  random_label: choose random labels.
#  num_imgs: number of instances to generate.
#  custom_label: custom label to be fixed.
# Logging:
#  sample_size: number of images included in sampled images.
# The single command-line argument is the path to the configuration file.
if len(sys.argv) < 2:
    sys.exit('Must provide a configuration file.')
# NOTE(review): this rebinds the module name `params` to the parsed config
# object; the `params` module itself is unreachable after this line. All
# later `params.<attr>` reads refer to the config instance.
params = params.Params(sys.argv[1])
+
# --------------------------
# Hyper-parameters.
# --------------------------
# General parameters.
BATCH_SIZE = params.batch_size    # images generated per session run
SAMPLE_SIZE = params.sample_size  # images included in each saved sample grid
# Validate explicitly instead of `assert`, which is silently stripped when
# Python runs with -O; a bad config should always abort with a clear message.
if SAMPLE_SIZE > BATCH_SIZE:
    sys.exit('sample_size ({}) must not exceed batch_size ({}).'.format(
        SAMPLE_SIZE, BATCH_SIZE))
NUM_IMGS = params.num_imgs        # total number of images to generate
+
# --------------------------
# Global directories.
# --------------------------
SAMPLES_DIR = 'random_samples'  # sampled image grids for visual inspection
INVERSES_DIR = 'inverses'       # hdf5 dataset output
# exist_ok=True avoids the check-then-create race of the original
# os.path.exists + os.makedirs pair and is a no-op if the dirs exist.
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(INVERSES_DIR, exist_ok=True)
+
# --------------------------
# Util functions.
# --------------------------
def one_hot(values, n_class=None):
    """Return one-hot (float64) rows for integer label(s) `values`.

    Args:
        values: int or sequence of ints, each in [0, n_class).
        n_class: number of classes; defaults to the module-level N_CLASS
            inferred from the generator's label input.

    Returns:
        np.ndarray with one row per label (or a single row for a scalar).
    """
    if n_class is None:
        n_class = N_CLASS
    return np.eye(n_class)[values]
+
def label_sampler(size=1, n_class=None):
    """Sample `size` integer labels uniformly from [0, n_class - 1].

    Uses np.random.randint (exclusive high): the original
    np.random.random_integers (inclusive high) was deprecated in NumPy 1.11
    and removed in NumPy 1.25, so random_integers(0, n_class - 1) becomes
    randint(0, n_class).

    Args:
        size: number of labels to draw.
        n_class: number of classes; defaults to the module-level N_CLASS.

    Returns:
        np.ndarray of ints with shape (size,).
    """
    if n_class is None:
        n_class = N_CLASS
    return np.random.randint(low=0, high=n_class, size=size)
+
# --------------------------
# Load Graph.
# --------------------------
# Load the generator as a TF-Hub module (TF1 hub.Module API).
generator = hub.Module(str(params.generator_path))

# Prefer the explicit 'generator' signature; fall back to 'default' for
# modules that only expose a single signature.
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'

input_info = generator.get_input_info_dict(gen_signature)
# Conditional GANs expose a label input 'y' in addition to the latent 'z'.
COND_GAN = 'y' in input_info

if COND_GAN:
    # Latent dimension read from the module's 'z' input spec.
    Z_DIM = input_info['z'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    # Number of classes read from the one-hot label input 'y'.
    N_CLASS = input_info['y'].get_shape().as_list()[1]
    label = tf.get_variable(name='label', dtype=tf.float32,
                            shape=[BATCH_SIZE, N_CLASS])
    # Merge the config's fixed inputs with the assignable latent/label vars.
    gen_in = dict(params.generator_fixed_inputs)
    gen_in['z'] = latent
    gen_in['y'] = label
    gen_img = generator(gen_in, signature=gen_signature)
else:
    # Unconditional module: single latent input under the 'default' key.
    Z_DIM = input_info['default'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    if (params.generator_fixed_inputs):
        gen_in = dict(params.generator_fixed_inputs)
        gen_in['z'] = latent
        gen_img = generator(gen_in, signature=gen_signature)
    else:
        gen_img = generator(latent, signature=gen_signature)

# Convert generated image to channels_first (NHWC -> NCHW); the hdf5 dataset
# below stores images in channels_first layout.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Define image shape (per-image [C, H, W], batch dimension dropped).
IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
# --------------------------
# Noise source.
# --------------------------
def noise_sampler(batch_size=None, z_dim=None):
    """Draw a [batch_size, z_dim] batch of standard-normal latent vectors.

    Args:
        batch_size: rows to draw; defaults to the module-level BATCH_SIZE,
            so existing no-argument callers are unchanged.
        z_dim: latent dimensionality; defaults to the module-level Z_DIM.

    Returns:
        np.ndarray of shape (batch_size, z_dim) sampled from N(0, 1).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    if z_dim is None:
        z_dim = Z_DIM
    return np.random.normal(size=[batch_size, z_dim])
+
# --------------------------
# Generation.
# --------------------------
# Start session. allow_soft_placement lets TF place ops on CPU when no GPU
# kernel is available.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())  # hub modules may contain lookup tables

# Output file: images as uint8, plus integer labels for conditional GANs.
out_file = h5py.File(os.path.join(INVERSES_DIR, params.dataset_out), 'w')
out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
                                     dtype='uint8')
if COND_GAN:
    out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')

# try/finally so the hdf5 handle and TF session are released even when a
# batch fails mid-run (the original never closed either resource).
try:
    for i in range(0, NUM_IMGS, BATCH_SIZE):
        # The last batch may be partial: keep only n_encs entries of it.
        n_encs = min(BATCH_SIZE, NUM_IMGS - i)

        if COND_GAN:
            if params.random_label:
                label_batch = label_sampler(BATCH_SIZE)
            else:
                label_batch = [params.custom_label]*BATCH_SIZE
            sess.run(label.assign(one_hot(label_batch)))

        sess.run(latent.assign(noise_sampler()))

        gen_images = sess.run(gen_img)

        # Map raw generator output to uint8 image data.
        gen_images = vs.data2img(gen_images)

        out_images[i:i+n_encs] = gen_images[:n_encs]
        if COND_GAN:
            out_labels[i:i+n_encs] = label_batch[:n_encs]

        # Save a small grid of this batch's samples for visual inspection.
        out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
        vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch)
        print('Saved samples for imgs: {}-{}.'.format(i, i+n_encs))
finally:
    out_file.close()
    sess.close()