From fb70ab05768fa4a54358dc1f304b68bc7aff6dae Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Sun, 8 Dec 2019 21:43:30 +0100 Subject: inversion json files --- inversion/random_sample.py | 144 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 inversion/random_sample.py (limited to 'inversion/random_sample.py') diff --git a/inversion/random_sample.py b/inversion/random_sample.py new file mode 100644 index 0000000..61cac9c --- /dev/null +++ b/inversion/random_sample.py @@ -0,0 +1,144 @@ +# ------------------------------------------------------------------------------ +# Generate random samples of the generator and save the images to a hdf5 file. +# ------------------------------------------------------------------------------ + +import h5py +import numpy as np +import os +import params +import sys +import tensorflow as tf +import tensorflow_hub as hub +import time +import visualize as vs + +# -------------------------- +# Hyper-parameters. +# -------------------------- +# Expected parameters: +# generator_path: path to generator module. +# generator_fixed_inputs: dictionary of fixed generator's input parameters. +# dataset_out: name for the output created dataset (hdf5 file). +# General parameters: +# batch_size: number of images generated at the same time. +# random_label: choose random labels. +# num_imgs: number of instances to generate. +# custom_label: custom label to be fixed. +# Logging: +# sample_size: number of images included in sampled images. +if len(sys.argv) < 2: + sys.exit('Must provide a configuration file.') +params = params.Params(sys.argv[1]) + +# -------------------------- +# Hyper-parameters. +# -------------------------- +# General parameters. +BATCH_SIZE = params.batch_size +SAMPLE_SIZE = params.sample_size +assert SAMPLE_SIZE <= BATCH_SIZE +NUM_IMGS = params.num_imgs + +# -------------------------- +# Global directories. +# -------------------------- +SAMPLES_DIR = 'random_samples' +INVERSES_DIR = 'inverses' +if not os.path.exists(SAMPLES_DIR): + os.makedirs(SAMPLES_DIR) +if not os.path.exists(INVERSES_DIR): + os.makedirs(INVERSES_DIR) + +# -------------------------- +# Util functions. +# -------------------------- +def one_hot(values): + return np.eye(N_CLASS)[values] + +def label_sampler(size=1): + return np.random.random_integers(low=0, high=N_CLASS-1, size=size) + +# -------------------------- +# Load Graph. +# -------------------------- +generator = hub.Module(str(params.generator_path)) + +gen_signature = 'generator' +if 'generator' not in generator.get_signature_names(): + gen_signature = 'default' + +input_info = generator.get_input_info_dict(gen_signature) +COND_GAN = 'y' in input_info + +if COND_GAN: + Z_DIM = input_info['z'].get_shape().as_list()[1] + latent = tf.get_variable(name='latent', dtype=tf.float32, + shape=[BATCH_SIZE, Z_DIM]) + N_CLASS = input_info['y'].get_shape().as_list()[1] + label = tf.get_variable(name='label', dtype=tf.float32, + shape=[BATCH_SIZE, N_CLASS]) + gen_in = dict(params.generator_fixed_inputs) + gen_in['z'] = latent + gen_in['y'] = label + gen_img = generator(gen_in, signature=gen_signature) +else: + Z_DIM = input_info['default'].get_shape().as_list()[1] + latent = tf.get_variable(name='latent', dtype=tf.float32, + shape=[BATCH_SIZE, Z_DIM]) + if (params.generator_fixed_inputs): + gen_in = dict(params.generator_fixed_inputs) + gen_in['z'] = latent + gen_img = generator(gen_in, signature=gen_signature) + else: + gen_img = generator(latent, signature=gen_signature) + +# Convert generated image to channels_first. +gen_img = tf.transpose(gen_img, [0, 3, 1, 2]) + +# Define image shape. +IMG_SHAPE = gen_img.get_shape().as_list()[1:] + +# -------------------------- +# Noise source. +# -------------------------- +def noise_sampler(): + return np.random.normal(size=[BATCH_SIZE, Z_DIM]) + +# -------------------------- +# Generation. +# -------------------------- +# Start session. +sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) +sess.run(tf.global_variables_initializer()) +sess.run(tf.tables_initializer()) + +# Output file. +out_file = h5py.File(os.path.join(INVERSES_DIR, params.dataset_out), 'w') +out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, + dtype='uint8') +if COND_GAN: + out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32') + +for i in range(0, NUM_IMGS, BATCH_SIZE): + n_encs = min(BATCH_SIZE, NUM_IMGS - i) + + if COND_GAN: + if params.random_label: + label_batch = label_sampler(BATCH_SIZE) + else: + label_batch = [params.custom_label]*BATCH_SIZE + sess.run(label.assign(one_hot(label_batch))) + + sess.run(latent.assign(noise_sampler())) + + gen_images = sess.run(gen_img) + + gen_images = vs.data2img(gen_images) + + out_images[i:i+n_encs] = gen_images[:n_encs] + if COND_GAN: + out_labels[i:i+n_encs] = label_batch[:n_encs] + + out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE]) + vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch) + print('Saved samples for imgs: {}-{}.'.format(i,i+n_encs)) -- cgit v1.2.3-70-g09d2