# ------------------------------------------------------------------------------
# Invert images through BigGAN-128: for every batch of input images, optimize
# the generator's latent vector and class-label input so that the generated
# images match the targets, then save the target images together with the
# recovered labels to an hdf5 file.
# ------------------------------------------------------------------------------
import h5py
import numpy as np
import os
import sys
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_probability as tfp
import time
import random
import visualize as vs
import argparse
from glob import glob

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# --------------------------
# Command-line arguments.
# --------------------------
# --input_dir:  directory holding the .jpg/.jpeg/.png images to invert.
# --tag:        run name; selects the ./outputs/<tag>/ directories and the
#               <tag>_dataset.hdf5 output file (defaults to the Unix time).
# --iterations: number of optimization steps run for each batch of images.
parser = argparse.ArgumentParser(description='Initialize the image search.')
parser.add_argument('--input_dir', required=True, help='Input directory of images')
parser.add_argument('--tag', default=str(int(time.time())), help='Tag this build')
parser.add_argument('--iterations', type=int, default=1000, help='Number of iterations to find vector')
params = parser.parse_args()
if params.iterations < 1:
    # The optimization loop needs at least one step, and the learning-rate
    # schedule uses iterations / 2 as its decay period, which must be > 0.
    parser.error('--iterations must be a positive integer')

# --------------------------
# Hyper-parameters.
# --------------------------
# BATCH_SIZE: number of images inverted together in one optimization run.
# SAMPLE_SIZE: number of images included in sampled image grids (not
#              referenced elsewhere in this script -- TODO confirm still used).
BATCH_SIZE = 20
SAMPLE_SIZE = 20
assert SAMPLE_SIZE <= BATCH_SIZE

# --------------------------
# Global directories.
# --------------------------
DATASET_OUT = "{}_dataset.hdf5".format(params.tag)
SAMPLES_DIR = './outputs/{}/samples'.format(params.tag)
INVERSES_DIR = './outputs/{}/inverses'.format(params.tag)
LOGS_DIR = './outputs/{}/logs'.format(params.tag)
for _out_dir in (SAMPLES_DIR, INVERSES_DIR, LOGS_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# --------------------------
# Logging.
# --------------------------
summary_writer = tf.summary.FileWriter(LOGS_DIR)


def log_stats(name, val, it):
    """Write the scalar ``val`` under tag ``name`` at step ``it`` to TensorBoard."""
    scalar_value = tf.Summary.Value(tag=name, simple_value=val)
    summary_writer.add_summary(tf.Summary(value=[scalar_value]), it)


# --------------------------
# Load Graph.
# --------------------------
generator = hub.Module('https://tfhub.dev/deepmind/biggan-128/2')
# Prefer the 'generator' signature; fall back to 'default' when it is absent.
gen_signature = ('generator'
                 if 'generator' in generator.get_signature_names()
                 else 'default')
input_info = generator.get_input_info_dict(gen_signature)

# Latent and class-label inputs are variables so they can be optimized.
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]
latent = tf.get_variable(name='latent', dtype=tf.float32, shape=[BATCH_SIZE, Z_DIM])
label = tf.get_variable(name='label', dtype=tf.float32, shape=[BATCH_SIZE, N_CLASS])

gen_in = {'truncation': 1.0, 'z': latent, 'y': label}
# Apply the generator and move channels to axis 1 (channels_first).
gen_img = tf.transpose(generator(gen_in, signature=gen_signature), [0, 3, 1, 2])

# The optimized "encoding" is the latent input itself.
encoding = latent
ENC_SHAPE = [Z_DIM]

# Optimization step counter (not trained).
inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)

# Target images: integer pixels with the generator output's per-image shape,
# presumably in [0, 255] given the rescaling below.
IMG_SHAPE = gen_img.get_shape().as_list()[1:]
target = tf.get_variable(name='target', dtype=tf.int32, shape=[BATCH_SIZE] + IMG_SHAPE)
# Rescale [0, 255] -> [-1, 1].
target_img = tf.cast(target, tf.float32) / 255. * 2.0 - 1.

# Monitor relu's activation.
gen_scope = 'module_apply_' + gen_signature + '/'
# Fraction of non-zero values in the generator's 'Generator_2/GBlock/Relu:0'
# tensor (looked up by name inside the module's apply scope).
activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()\
    .get_tensor_by_name(gen_scope + "Generator_2/GBlock/Relu:0"))

# --------------------------
# Reconstruction losses.
# --------------------------
# Mse loss for image comparison.
# Both images are in [-1, 1]; halving the difference keeps each squared error
# in [0, 1].
pix_square_diff = tf.square((target_img - gen_img) / 2.0)
mse_loss = tf.reduce_mean(pix_square_diff)
# Per-image MSE: mean over all non-batch axes.
img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3])

# Use custom features for image comparison.
feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
# Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
# Convert images to appropriate size for feature extraction.
height, width = hub.get_expected_image_size(feature_extractor)
gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
target_img_1 = tf.image.resize_images(target_img_1, [height, width])
# Compare generated and target images on the 'InceptionV3/Mixed_7a' output of
# the module's 'image_feature_vector' signature.
gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True,
    signature='image_feature_vector')['InceptionV3/Mixed_7a']
target_feat = feature_extractor(dict(images=target_img_1), as_dict=True,
    signature='image_feature_vector')['InceptionV3/Mixed_7a']
# One row of squared feature differences per image.
feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [BATCH_SIZE, -1])
feat_loss = tf.reduce_mean(feat_square_diff)
img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)

# --------------------------
# Regularization losses.
# --------------------------
# Negative log-likelihood of the latent under a standard normal prior, offset
# by the log-density at the mode (0). The constants cancel, so this equals
# mean(latent ** 2) / 2 and is 0 when latent == 0.
norm_dist = tfp.distributions.Normal(0.0, 1.0)
likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
mode_log_prob = norm_dist.log_prob(0.0)
likeli_loss += mode_log_prob
# Per image reconstruction error.
img_rec_err = 1.0 * img_mse_err + 1.0 * img_feat_err
# Batch reconstruction error.
rec_loss = 1.0 * mse_loss + 1.0 * feat_loss
# Total inversion loss.
inv_loss = rec_loss + 0.1 * likeli_loss

# --------------------------
# Optimizer.
# --------------------------
# Staircase schedule: learning rate 0.1 for the first iterations / 2 steps of
# inv_step, then 0.01.
lrate = tf.train.exponential_decay(0.1, inv_step, params.iterations / 2, 0.1, staircase=True)
trained_params = [latent, label]
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params, global_step=inv_step)
# Re-initializes Adam's internal variables so each run starts fresh.
reinit_optimizer = tf.variables_initializer(optimizer.variables())


# --------------------------
# Noise source.
# --------------------------
def noise_sampler():
    """Return a [BATCH_SIZE, Z_DIM] array of standard-normal latent samples."""
    return np.random.normal(size=[BATCH_SIZE, Z_DIM])


def label_sampler(shape=None):
    """Return random soft class labels, one probability vector per row.

    For each row, between 1 and ``shape[1]`` draws each pick a random class
    and set it to a uniform [0, 1) weight (a later draw on the same class
    overwrites an earlier one). Each row is then normalized to sum to 1.

    Args:
        shape: (rows, classes) of the result; defaults to (BATCH_SIZE, N_CLASS).

    Returns:
        A float ndarray of the requested shape.
    """
    # None instead of a list literal avoids a shared mutable default.
    if shape is None:
        shape = (BATCH_SIZE, N_CLASS)
    n_rows, n_classes = shape[0], shape[1]
    labels = np.zeros(shape)
    for i in range(n_rows):
        for _ in range(random.randint(1, n_classes)):
            j = random.randint(0, n_classes - 1)
            labels[i, j] = random.random()
        labels[i] /= labels[i].sum()
    return labels


# --------------------------
# Dataset.
# --------------------------
paths = glob(os.path.join(params.input_dir, '*.jpg')) + \
        glob(os.path.join(params.input_dir, '*.jpeg')) + \
        glob(os.path.join(params.input_dir, '*.png'))
sample_images = [vs.load_image(fn, 128) for fn in sorted(paths)]
ACTUAL_NUM_IMGS = len(sample_images)  # number of images to be inverted.
print("Number of images: {}".format(ACTUAL_NUM_IMGS))
if not sample_images:
    sys.exit("No .jpg, .jpeg or .png images found in {}".format(params.input_dir))

# Pad by repeating the last image so the count is a multiple of BATCH_SIZE.
# Use list.extend with a list of images: `list += ndarray` would splice in the
# array's rows instead of appending the image.
num_padding = -ACTUAL_NUM_IMGS % BATCH_SIZE
sample_images.extend([sample_images[-1]] * num_padding)
NUM_IMGS = ACTUAL_NUM_IMGS + num_padding
sample_images = np.array(sample_images)
print(sample_images.shape)


def sample_images_gen():
    """Yield consecutive BATCH_SIZE-image slices of ``sample_images``."""
    for start in range(0, NUM_IMGS, BATCH_SIZE):
        yield sample_images[start:start + BATCH_SIZE]


image_gen = sample_images_gen()
assert(NUM_IMGS % BATCH_SIZE == 0)

# --------------------------
# Generation.
# --------------------------
# Start session.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())


def _make_setter(var, session):
    """Build a placeholder + assign op for ``var`` once; return a loader.

    Running ``var.assign(numpy_value)`` inside the batch loop would add a new
    assign op to the graph on every call and bake the value into the graph as
    a constant, so the graph would keep growing. Feeding a placeholder into a
    single assign op keeps the graph fixed.

    Args:
        var: the variable to overwrite.
        session: session used to run the assign op.

    Returns:
        A function ``load(value)`` that writes ``value`` into ``var``.
    """
    value_in = tf.placeholder(var.dtype.base_dtype, shape=var.get_shape())
    assign_op = var.assign(value_in)

    def _load(value):
        session.run(assign_op, feed_dict={value_in: value})

    return _load


set_target = _make_setter(target, sess)
set_label = _make_setter(label, sess)
set_latent = _make_setter(latent, sess)
set_inv_step = _make_setter(inv_step, sess)

# Gradient descent w.r.t. generator's inputs.
it = 0  # step counter across all batches, used for logging
out_pos = 0  # logged as 'out pos'; not updated in this script
count = 0  # index of the first image of the current batch

# Output file; the `with` block flushes and closes it even if the loop raises.
with h5py.File(os.path.join(INVERSES_DIR, DATASET_OUT), 'w') as out_file:
    out_images = out_file.create_dataset('xtrain', [NUM_IMGS] + IMG_SHAPE, dtype='uint8')
    out_labels = out_file.create_dataset('ytrain', [NUM_IMGS, N_CLASS], dtype='float32')

    for image_batch in image_gen:
        # Set target, then start from a random label and a random latent.
        set_target(image_batch)
        set_label(label_sampler())
        set_latent(noise_sampler())
        # Init optimizer.
        set_inv_step(0)
        sess.run(reinit_optimizer)

        # Main optimization loop.
        for _ in range(params.iterations):
            _inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss, \
                _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, rec_loss,
                                      likeli_loss, lrate, inv_train_op])

            # Every 100 iterations save logs with training information.
            if it % 100 == 0:
                _act_rate = sess.run(activation_rate)
                print('activation_rate={:.4f}'.format(_act_rate))
                log_stats('activation rate', _act_rate, it)
                sys.stdout.flush()

                # Log tensorboard's statistics.
                log_stats('total loss', _inv_loss, it)
                log_stats('mse loss', _mse_loss, it)
                log_stats('feat loss', _feat_loss, it)
                log_stats('rec loss', _rec_loss, it)
                log_stats('likeli loss', _likeli_loss, it)
                log_stats('out pos', out_pos, it)
                log_stats('lrate', _lrate, it)
                summary_writer.flush()

                gen_images = sess.run(gen_img)
                inv_batch = vs.interleave(vs.data2img(image_batch), vs.data2img(gen_images))
                inv_batch = vs.grid_transform(inv_batch)
                vs.save_image('{}/initial_{:03d}_{:05d}.png'.format(SAMPLES_DIR, count, it), inv_batch)

            it += 1

        # Read the final optimized label and images for this batch. The
        # gen_images from the last logging step can be up to 99 steps stale
        # (or not exist at all when no logging step fell in this batch).
        label_batch, gen_images = sess.run([label, gen_img])
        print(label_batch.shape)

        out_images[count:count + BATCH_SIZE] = image_batch
        out_labels[count:count + BATCH_SIZE] = label_batch

        out_batch = vs.grid_transform(vs.data2img(gen_images))
        vs.save_image('{}/initial_generated_{}.png'.format(SAMPLES_DIR, count), out_batch)
        print('Saved samples for imgs: {}-{}.'.format(count, count + BATCH_SIZE))
        count += BATCH_SIZE

summary_writer.close()
sess.close()