# ------------------------------------------------------------------------------
# Generate random samples of the generator and save the images to a hdf5 file.
# ------------------------------------------------------------------------------

import h5py
import numpy as np
import os
import params
import sys
import tensorflow as tf
import tensorflow_hub as hub
import time
import visualize as vs

# --------------------------
# Configuration.
# --------------------------
# Expected parameters:
#   generator_path: path to the TF-Hub generator module.
#   generator_fixed_inputs: dictionary of generator inputs that stay fixed;
#       the latent `z` (and, for conditional generators, the label `y`) is
#       added on top of it.
#   dataset_out: file name of the output hdf5 dataset (written inside the
#       `inverses` directory).
# General parameters:
#   batch_size: number of images generated per session run.
#   random_label: conditional generators only; if true, draw a uniformly
#       random class per image, otherwise use `custom_label` for every image.
#   num_imgs: total number of images to generate.
#   custom_label: class index used for every image when `random_label` is
#       false.
# Logging:
#   sample_size: number of images of each batch included in the saved PNG grid.

if len(sys.argv) < 2:
    sys.exit('Must provide a configuration file.')

# NOTE: rebinding `params` shadows the imported module from here on.
params = params.Params(sys.argv[1])

# --------------------------
# Hyper-parameters.
# --------------------------
# General parameters.
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size
# Validate explicitly: an `assert` would be stripped under `python -O`.
if SAMPLE_SIZE > BATCH_SIZE:
    sys.exit('sample_size ({}) must not exceed batch_size ({}).'.format(
        SAMPLE_SIZE, BATCH_SIZE))
NUM_IMGS = params.num_imgs

# --------------------------
# Global directories.
# --------------------------
SAMPLES_DIR = 'random_samples'
INVERSES_DIR = 'inverses'
for _directory in (SAMPLES_DIR, INVERSES_DIR):
    if not os.path.exists(_directory):
        os.makedirs(_directory)

# --------------------------
# Util functions.
# --------------------------
# Both helpers read the module-level N_CLASS, which is only defined once a
# conditional generator has been loaded (see "Load Graph." below). They are
# only called on that path.
def one_hot(values):
    """Return one-hot rows (N_CLASS columns) for the class indices `values`."""
    return np.eye(N_CLASS)[values]


def label_sampler(size=1):
    """Draw `size` class indices uniformly from [0, N_CLASS - 1]."""
    # np.random.random_integers is deprecated. randint's upper bound is
    # exclusive, so randint(0, N_CLASS) covers the same inclusive range.
    return np.random.randint(low=0, high=N_CLASS, size=size)


# --------------------------
# Load Graph.
# --------------------------
generator = hub.Module(str(params.generator_path))
# Use the dedicated 'generator' signature when the module has one, otherwise
# fall back to the default signature.
gen_signature = ('generator' if 'generator' in generator.get_signature_names()
                 else 'default')
input_info = generator.get_input_info_dict(gen_signature)
# A conditional generator exposes a label input 'y'.
COND_GAN = 'y' in input_info


def _input_width(key):
    """Return the second (per-example) dimension of module input `key`."""
    return input_info[key].get_shape().as_list()[1]


# The latent codes (and labels) are held in variables whose values are
# overwritten before each generation step.
Z_DIM = _input_width('z' if COND_GAN else 'default')
latent = tf.get_variable(name='latent', dtype=tf.float32,
                         shape=[BATCH_SIZE, Z_DIM])
if COND_GAN:
    N_CLASS = _input_width('y')
    label = tf.get_variable(name='label', dtype=tf.float32,
                            shape=[BATCH_SIZE, N_CLASS])
    gen_in = dict(params.generator_fixed_inputs)
    gen_in.update(z=latent, y=label)
    gen_img = generator(gen_in, signature=gen_signature)
elif params.generator_fixed_inputs:
    gen_in = dict(params.generator_fixed_inputs, z=latent)
    gen_img = generator(gen_in, signature=gen_signature)
else:
    gen_img = generator(latent, signature=gen_signature)

# Convert generated images to channels_first. The [0, 3, 1, 2] permutation
# assumes the module emits channels_last (NHWC) batches.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
# Per-image shape (batch dimension dropped); used to size the output dataset.
IMG_SHAPE = gen_img.get_shape().as_list()[1:]

# --------------------------
# Noise source.
# --------------------------
def noise_sampler():
    """Return a [BATCH_SIZE, Z_DIM] batch of standard-normal latent codes."""
    return np.random.normal(loc=0.0, scale=1.0, size=(BATCH_SIZE, Z_DIM))


# --------------------------
# Generation.
# --------------------------
# Start session.
session_config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session(config=session_config)
for init_op in (tf.global_variables_initializer(), tf.tables_initializer()):
    sess.run(init_op)

# Output file.
# The context manager guarantees the hdf5 file is flushed and closed, even if
# generation fails part-way.
with h5py.File(os.path.join(INVERSES_DIR, params.dataset_out), 'w') as out_file:
    out_images = out_file.create_dataset('xtrain', [NUM_IMGS] + IMG_SHAPE,
                                         dtype='uint8')
    if COND_GAN:
        out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,),
                                             dtype='uint32')

    # Build the assignment ops once and feed new values through placeholders.
    # Calling `variable.assign(numpy_array)` inside the loop would add a new
    # constant and assign op to the graph on every batch, so the graph would
    # grow without bound and each step would get slower.
    latent_in = tf.placeholder(tf.float32, shape=latent.get_shape())
    assign_latent = latent.assign(latent_in)
    if COND_GAN:
        label_in = tf.placeholder(tf.float32, shape=label.get_shape())
        assign_label = label.assign(label_in)

    for i in range(0, NUM_IMGS, BATCH_SIZE):
        # The last batch may be partial: a full batch is generated, but only
        # the first n_encs images (and labels) are stored.
        n_encs = min(BATCH_SIZE, NUM_IMGS - i)
        if COND_GAN:
            if params.random_label:
                label_batch = label_sampler(BATCH_SIZE)
            else:
                label_batch = [params.custom_label] * BATCH_SIZE
            sess.run(assign_label, feed_dict={label_in: one_hot(label_batch)})
        sess.run(assign_latent, feed_dict={latent_in: noise_sampler()})
        gen_images = vs.data2img(sess.run(gen_img))

        out_images[i:i + n_encs] = gen_images[:n_encs]
        if COND_GAN:
            out_labels[i:i + n_encs] = label_batch[:n_encs]

        # Save a grid of the first SAMPLE_SIZE images of this batch for
        # visual inspection.
        out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
        vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch)
        print('Saved samples for imgs: {}-{}.'.format(i, i + n_encs))