# ------------------------------------------------------------------------------
# Invert a directory of images through the BigGAN-128 generator (TF Hub): for
# each batch of target images, optimize the generator's latent and label inputs
# by gradient descent, then save the images and the resulting labels to an hdf5
# file.
# ------------------------------------------------------------------------------
import argparse
import os
import sys
import time
from glob import glob

import h5py
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

import visualize as vs

# --------------------------
# Command-line arguments.
# --------------------------
#   --input_dir: directory containing the .jpg/.jpeg/.png target images.
#   --tag: name of this run; used for the output directories and hdf5 file.
#   --inv_it: gradient-descent iterations per batch of target images.
#   --log_activation_layer: name of a tensor inside the generator module whose
#       fraction of non-zero entries is logged during optimization.
#   --feature_extractor_output: output key of the feature extractor used for
#       the feature reconstruction loss.
parser = argparse.ArgumentParser(description='Initialize the image search.')
parser.add_argument('--input_dir', required=True,
                    help='Input directory of images')
parser.add_argument('--tag', default=str(int(time.time())),
                    help='Tag this build')
parser.add_argument('--inv_it', type=int, default=1000,
                    help='Number of inversion iterations per batch')
parser.add_argument('--log_activation_layer', required=True,
                    help="Tensor name inside the generator module, relative "
                         "to its 'module_apply_<signature>/' scope, whose "
                         "activation rate is logged")
parser.add_argument('--feature_extractor_output', default='default',
                    help='Output key of the feature extractor used for the '
                         'feature loss')
params = parser.parse_args()

# --------------------------
# Hyper-parameters.
# --------------------------
BATCH_SIZE = 20   # Number of images inverted jointly.
SAMPLE_SIZE = 20  # Images per batch included in the saved sample grids.
if SAMPLE_SIZE > BATCH_SIZE:
    raise ValueError('SAMPLE_SIZE must not exceed BATCH_SIZE.')
# Kept equal to --inv_it so the learning-rate schedule, which is expressed in
# terms of INVERSION_ITERATIONS, matches the number of iterations actually run.
INVERSION_ITERATIONS = params.inv_it

# --------------------------
# Global directories.
# --------------------------
DATASET_OUT = "{}_dataset.hdf5".format(params.tag)
SAMPLES_DIR = './outputs/{}/samples'.format(params.tag)
INVERSES_DIR = './outputs/{}/inverses'.format(params.tag)
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(INVERSES_DIR, exist_ok=True)

# --------------------------
# Load Graph.
# --------------------------

# BigGAN-128 generator from TF Hub. Use the 'generator' signature when the
# module exposes one, otherwise fall back to 'default'.
generator = hub.Module('https://tfhub.dev/deepmind/biggan-128/2')
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'
input_info = generator.get_input_info_dict(gen_signature)

# Trainable generator inputs, sized from the module's declared input shapes:
# one latent row and one label row per batch element.
Z_DIM = input_info['z'].get_shape().as_list()[1]
latent = tf.get_variable(name='latent', dtype=tf.float32,
                         shape=[BATCH_SIZE, Z_DIM])
N_CLASS = input_info['y'].get_shape().as_list()[1]
label = tf.get_variable(name='label', dtype=tf.float32,
                        shape=[BATCH_SIZE, N_CLASS])

gen_in = {'truncation': 1.0, 'z': latent, 'y': label}
gen_img = generator(gen_in, signature=gen_signature)
# Convert generated image to channels_first.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Override intermediate layer.
encoding = latent
ENC_SHAPE = [Z_DIM]

# Step counter (reset to 0 before each batch and advanced by the optimizer).
inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)

# Define target image: integer pixels in [0, 255] with the same per-image shape
# as the (channels-first) generator output.
IMG_SHAPE = gen_img.get_shape().as_list()[1:]
target = tf.get_variable(name='target', dtype=tf.int32,
                         shape=[BATCH_SIZE, ] + IMG_SHAPE)
target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1.  # Norm to [-1, 1].

# Monitor the activation rate (fraction of non-zero entries) of a tensor inside
# the generator, presumably a ReLU output.
gen_scope = 'module_apply_' + gen_signature + '/'
_act_tensor_name = params.log_activation_layer
if ':' not in _act_tensor_name:
    # get_tensor_by_name() requires '<op_name>:<output_index>'; accept a bare
    # op name and use its first output.
    _act_tensor_name += ':0'
activation_rate = 1.0 - tf.nn.zero_fraction(
    tf.get_default_graph().get_tensor_by_name(gen_scope + _act_tensor_name))

# --------------------------
# Reconstruction losses.
# --------------------------
# Mse loss for image comparison. Both images are in [-1, 1], so halving the
# difference keeps each squared term in [0, 1].
pix_square_diff = tf.square((target_img - gen_img) / 2.0)
mse_loss = tf.reduce_mean(pix_square_diff)
img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1, 2, 3])

# Use custom features for image comparison.
feature_extractor = hub.Module(
    "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
# Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])

# Convert images to appropriate size for feature extraction.
height, width = hub.get_expected_image_size(feature_extractor)
gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
target_img_1 = tf.image.resize_images(target_img_1, [height, width])

# The same feature-extractor output is compared for generated and target images.
gen_feat = feature_extractor(
    dict(images=gen_img_1), as_dict=True,
    signature='image_feature_vector')[params.feature_extractor_output]
target_feat = feature_extractor(
    dict(images=target_img_1), as_dict=True,
    signature='image_feature_vector')[params.feature_extractor_output]

# Flatten per-image so the error can be reported per image and per batch.
feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
                              [BATCH_SIZE, -1])
feat_loss = tf.reduce_mean(feat_square_diff)
img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)

# --------------------------
# Regularization losses.
# --------------------------
import tensorflow_probability as tfp  # noqa: E402  (only needed for the prior)

# Negative log-likelihood of the latent under a standard normal prior, shifted
# by log p(0) so that the loss is 0 at the prior's mode.
norm_dist = tfp.distributions.Normal(0.0, 1.0)
likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
mode_log_prob = norm_dist.log_prob(0.0)
likeli_loss += mode_log_prob

# Per image reconstruction error.
img_rec_err = 1.0 * img_mse_err + 1.0 * img_feat_err
# Batch reconstruction error.
rec_loss = 1.0 * mse_loss + 1.0 * feat_loss
# Total inversion loss.
inv_loss = rec_loss + 0.1 * likeli_loss

# --------------------------
# Optimizer.
# --------------------------
# Learning rate 0.1, divided by 10 (staircase) every INVERSION_ITERATIONS / 2
# steps of inv_step.
lrate = tf.train.exponential_decay(0.1, inv_step, INVERSION_ITERATIONS / 2,
                                   0.1, staircase=True)
# Only the generator inputs are optimized; the generator itself stays fixed.
trained_params = [latent, label]
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
                                  global_step=inv_step)
# Resets Adam's slot variables so every batch starts from a fresh optimizer.
reinit_optimizer = tf.variables_initializer(optimizer.variables())

# --------------------------
# Noise source.
# --------------------------


def noise_sampler():
    """Draw a fresh latent batch.

    Returns:
        ndarray of shape (BATCH_SIZE, Z_DIM) with i.i.d. N(0, 1) entries.
    """
    return np.random.normal(size=[BATCH_SIZE, Z_DIM])


def label_init(shape=None):
    """Draw an initial label batch with entries uniform in [0, 0.01).

    Args:
        shape: output shape; ``None`` means (BATCH_SIZE, N_CLASS).
    """
    if shape is None:
        shape = [BATCH_SIZE, N_CLASS]
    return np.random.uniform(low=0.00, high=0.01, size=shape)


# --------------------------
# Dataset.
# --------------------------
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')
# Case-insensitive extension match; hidden files are skipped (as glob would).
paths = sorted(
    os.path.join(params.input_dir, fn)
    for fn in os.listdir(params.input_dir)
    if not fn.startswith('.') and fn.lower().endswith(IMG_EXTENSIONS))
if not paths:
    sys.exit('No .jpg/.jpeg/.png images found in {}'.format(params.input_dir))

# NOTE(review): assumes vs.load_image(fn, 128) returns a uint8 array laid out
# like IMG_SHAPE (channels first) so it can be loaded into `target` - confirm.
sample_images = [vs.load_image(fn, 128) for fn in paths]

ACTUAL_NUM_IMGS = len(sample_images)  # number of images to be inverted.
print("Number of images: {}".format(ACTUAL_NUM_IMGS))

# Pad with copies of the last image so the count is a multiple of BATCH_SIZE.
# Padded copies are optimized with their batch but never written to the output.
num_pad = -ACTUAL_NUM_IMGS % BATCH_SIZE
sample_images.extend([sample_images[-1]] * num_pad)
NUM_IMGS = ACTUAL_NUM_IMGS + num_pad
sample_images = np.array(sample_images)
assert NUM_IMGS % BATCH_SIZE == 0


def sample_images_gen():
    """Yield consecutive BATCH_SIZE-sized slices of `sample_images`."""
    for start in range(0, NUM_IMGS, BATCH_SIZE):
        yield sample_images[start:start + BATCH_SIZE]


image_gen = sample_images_gen()

# --------------------------
# Logging.
# --------------------------
LOGS_DIR = './outputs/{}/logs'.format(params.tag)
summary_writer = tf.summary.FileWriter(LOGS_DIR)


def log_stats(tag, value, step):
    """Write scalar `value` under `tag` at `step` to the TensorBoard log."""
    summary = tf.Summary(
        value=[tf.Summary.Value(tag=tag, simple_value=float(value))])
    summary_writer.add_summary(summary, step)


# --------------------------
# Generation.
# --------------------------
# Start session.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())

# Gradient descent w.r.t. generator's inputs.
it = 0       # Global iteration counter across all batches (for logging).
out_pos = 0  # Index of the first image of the current batch.
start_time = time.time()
with h5py.File(os.path.join(INVERSES_DIR, DATASET_OUT), 'w') as out_file:
    # Only the real (non-padding) images are stored.
    out_images = out_file.create_dataset(
        'xtrain', [ACTUAL_NUM_IMGS, ] + IMG_SHAPE, dtype='uint8')
    out_labels = out_file.create_dataset(
        'ytrain', (ACTUAL_NUM_IMGS,), dtype='uint32')

    for image_batch in image_gen:
        # Variable.load() feeds the value through the variable's initializer
        # rather than adding a new assign op (plus an embedded constant holding
        # the whole batch) to the graph on every batch.
        target.load(image_batch, sess)
        # Start with a random label and a random latent vector.
        label.load(label_init(), sess)
        latent.load(noise_sampler(), sess)
        # Init optimizer.
        inv_step.load(0, sess)
        sess.run(reinit_optimizer)

        # Main optimization loop.
        for _ in range(params.inv_it):
            (_inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss,
             _lrate, _) = sess.run([inv_loss, mse_loss, feat_loss, rec_loss,
                                    likeli_loss, lrate, inv_train_op])

            # Every 100 iterations save logs with training information.
            if it % 100 == 0:
                etime = time.time() - start_time
                _act_rate = sess.run(activation_rate)
                print('activation_rate={:.4f}'.format(_act_rate))
                log_stats('activation rate', _act_rate, it)
                sys.stdout.flush()

                # Log tensorboard's statistics.
                log_stats('total loss', _inv_loss, it)
                log_stats('mse loss', _mse_loss, it)
                log_stats('feat loss', _feat_loss, it)
                log_stats('rec loss', _rec_loss, it)
                log_stats('reg loss', _likeli_loss, it)
                log_stats('out pos', out_pos, it)
                log_stats('lrate', _lrate, it)
                log_stats('elapsed time', etime, it)
                summary_writer.flush()

                # Progress grid: targets interleaved with current generations.
                gen_images = sess.run(gen_img)
                inv_batch = vs.interleave(
                    image_batch[BATCH_SIZE - SAMPLE_SIZE:],
                    vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
                inv_batch = vs.grid_transform(inv_batch)
                vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it),
                              inv_batch)

            it += 1

        # Final generations and optimized labels for this batch.
        gen_images = sess.run(gen_img)
        label_batch = sess.run(label)

        n_valid = min(BATCH_SIZE, ACTUAL_NUM_IMGS - out_pos)
        out_images[out_pos:out_pos + n_valid] = image_batch[:n_valid]
        # The optimized label is a continuous per-class vector; store its argmax
        # class index to match the scalar uint32 'ytrain' dataset.
        out_labels[out_pos:out_pos + n_valid] = np.argmax(
            label_batch[:n_valid], axis=1)

        out_batch = vs.grid_transform(vs.data2img(gen_images[:SAMPLE_SIZE]))
        vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, out_pos),
                      out_batch)
        print('Saved samples for imgs: {}-{}.'.format(out_pos,
                                                      out_pos + n_valid))
        out_pos += BATCH_SIZE

summary_writer.close()
sess.close()