# ------------------------------------------------------------------------------
# Implementation of the inverse of Generator by Gradient descent w.r.t.
# generator's inputs, for many intermediate layers.
# ------------------------------------------------------------------------------

import glob
import h5py
import itertools
import numpy as np
import os
import params
import PIL
import scipy
import sys
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
import visualize as vs

# --------------------------
# Hyper-parameters.
# --------------------------
# All values below are read as attributes of the configuration object built
# from the file given as the first command-line argument.
#
# Expected parameters:
#   tag: experiment name; used in the output directories and image file names.
#   generator_path: path to generator module.
#   generator_fixed_inputs: dictionary of fixed generator's input parameters.
#   dataset: name of the dataset (hdf5 file).
#   out_dataset: name for the output inverted dataset (hdf5 file).
# General parameters:
#   batch_size: number of images inverted at the same time.
#   inv_it: number of iterations to invert an image.
#   inv_layer: 'latent' or name of the tensor of the custom layer to be
#     inverted.
#   lr: learning rate.
#   decay_lr: whether to apply an exponential decay on the learning rate.
#   decay_n: number of exponential decays on the learning rate.
#   custom_grad_relu: replace relus with custom gradient.
# Logging:
#   sample_size: number of images included in sampled images.
#   save_progress: whether to save intermediate images during optimization.
#   log_z_norm: log the norm of different sections of z.
#   log_activation_layer: name of the layer whose fraction of non-zero
#     activations is logged.
# Losses:
#   mse: use the mean squared error on pixels for image comparison.
#   features: use features extracted by a feature extractor for image
#     comparison.
#   feature_extractor_path: path to feature extractor module.
#   feature_extractor_output: output name from feature extractor.
#   likeli_loss: regularization loss on the log likelihood of encodings.
#   norm_loss: regularization loss on the norm of encodings.
#   dist_loss: whether to include a loss on the dist between g1(z) and enc.
#   lambda_mse: coefficient for mse loss.
#   lambda_feat: coefficient for features loss.
#   lambda_reg: coefficient for regularization loss on latent.
#   lambda_dist: coefficient for the l1 distance loss (dist_loss).
# Latent:
#   clipping: whether to clip encoding values after every update.
#   stochastic_clipping: whether to consider stochastic clipping.
#   clip: clipping bound.
#   pre_trained_latent: load pre trained fixed latent.
#   fixed_z: do not train the latent vector.
# Initialization:
#   init_gen_dist: initialize encodings from the generated distribution.
#   init_lo: init min value.
#   init_hi: init max value.

if len(sys.argv) < 2:
    sys.exit('Must provide a configuration file.')

# NOTE: from here on `params` is the loaded configuration object; it shadows
# the imported `params` module.
params = params.Params(sys.argv[1])

# --------------------------
# Global directories.
# --------------------------
if params.inv_layer == 'latent':
    LATENT_TAG = 'latent'
else:
    LATENT_TAG = 'dense'
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size

INVERSES_DIR = os.path.join('inverses', params.tag)
LOGS_DIR = os.path.join(INVERSES_DIR, LATENT_TAG, 'logs')
SAMPLES_DIR = os.path.join(INVERSES_DIR, LATENT_TAG, 'samples')

# Create every output directory that does not exist yet.
for _out_dir in (LOGS_DIR, SAMPLES_DIR, INVERSES_DIR):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir)

# --------------------------
# Util functions.
# --------------------------


def one_hot(values):
    """Return the rows of an N_CLASS x N_CLASS identity matrix at `values`.

    N_CLASS is a module-level name that is only assigned when the generator
    is conditional, so this helper is only usable in that case.
    """
    identity = np.eye(N_CLASS)
    return identity[values]


# --------------------------
# Logging.
# --------------------------
summary_writer = tf.summary.FileWriter(LOGS_DIR)


def log_stats(name, val, it):
    """Write the scalar `val` under tag `name` at step `it` to the summary."""
    entry = tf.Summary.Value(tag=name, simple_value=val)
    summary = tf.Summary(value=[entry])
    summary_writer.add_summary(summary, it)


# --------------------------
# Load Graph.
# --------------------------
generator = hub.Module(str(params.generator_path))

# Use the 'generator' signature when the module exposes one, and fall back to
# the 'default' signature otherwise.
signature_names = generator.get_signature_names()
gen_signature = 'generator' if 'generator' in signature_names else 'default'

input_info = generator.get_input_info_dict(gen_signature)
# A generator that takes a 'y' input is treated as class-conditional.
COND_GAN = 'y' in input_info

# The latent input is named 'z' for conditional generators and 'default'
# otherwise; its second dimension gives the latent size.
z_input_name = 'z' if COND_GAN else 'default'
Z_DIM = input_info[z_input_name].get_shape().as_list()[1]
latent = tf.get_variable(name='latent', dtype=tf.float32,
                         shape=[BATCH_SIZE, Z_DIM])

if COND_GAN:
    N_CLASS = input_info['y'].get_shape().as_list()[1]
    label = tf.get_variable(name='label', dtype=tf.float32,
                            shape=[BATCH_SIZE, N_CLASS])
    gen_in = dict(params.generator_fixed_inputs, z=latent, y=label)
    gen_img = generator(gen_in, signature=gen_signature)
elif params.generator_fixed_inputs:
    gen_in = dict(params.generator_fixed_inputs, z=latent)
    gen_img = generator(gen_in, signature=gen_signature)
else:
    gen_img = generator(latent, signature=gen_signature)

# Convert generated image to channels_first.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Override intermediate layer.
if params.inv_layer == 'latent':
    # Inverting the latent itself: the encoding is the latent variable.
    encoding = latent
    ENC_SHAPE = [Z_DIM]
else:
    # Look up the requested tensor inside the applied module, create a free
    # variable of the same per-sample shape and swap it in for the tensor, so
    # the layers that consumed the tensor now read the variable.
    layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
    graph = tf.get_default_graph()
    gen_encoding = graph.get_tensor_by_name(layer_name)
    ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
    encoding = tf.get_variable(name='encoding', dtype=tf.float32,
                               shape=[BATCH_SIZE] + ENC_SHAPE)
    tf.contrib.graph_editor.swap_ts(gen_encoding,
                                    tf.convert_to_tensor(encoding))

# Step counter.
inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)

# Define target image: one variable with the same per-sample shape as the
# (channels_first) generated image.
IMG_SHAPE = gen_img.get_shape().as_list()[1:]
# NOTE(review): originally annotated "normally this is the real [0-255]image",
# but the rescaling line below is disabled -- confirm the expected value range.
target = tf.get_variable(name='target', dtype=tf.float32,
                         shape=[BATCH_SIZE] + IMG_SHAPE)
# target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1. # Norm to [-1, 1].
# The target variable is compared against the generated image as stored; no
# rescaling is applied here.
target_img = target

# Custom Gradient for Relus.
if params.custom_grad_relu:
    grad_lambda = tf.train.exponential_decay(
        0.1, inv_step, params.inv_it / 5, 0.1, staircase=False)

    @tf.custom_gradient
    def relu_custom_grad(x):
        """Relu whose backward pass also lets some gradient through x < 0."""
        def grad(dy):
            # Where x >= 0 the gradient passes unchanged. Elsewhere only the
            # negative part of the incoming gradient is kept, scaled by the
            # decaying factor grad_lambda.
            non_negative = x >= 0
            negative_dy = tf.where(dy < 0, dy, tf.zeros_like(dy))
            return tf.where(non_negative, dy, grad_lambda * negative_dy)
        return tf.nn.relu(x), grad

    # Swap the output of every op inside the generator scope whose name
    # contains 'Relu' for the custom-gradient version.
    gen_scope = 'module_apply_' + gen_signature + '/'
    for op in tf.get_default_graph().get_operations():
        if 'Relu' not in op.name or gen_scope not in op.name:
            continue
        assert len(op.inputs) == 1
        assert len(op.outputs) == 1
        new_out = relu_custom_grad(op.inputs[0])
        tf.contrib.graph_editor.swap_ts(op.outputs[0], new_out)

# Operations to clip the values of the encodings.
if params.clipping or params.stochastic_clipping:
    assert params.clip >= 0
    if not params.stochastic_clipping:
        new_enc = tf.clip_by_value(latent, -params.clip, params.clip)
    else:
        # Entries whose magnitude reaches the bound are redrawn uniformly
        # inside [-clip, clip]; all other entries are kept.
        out_of_range = tf.abs(latent) >= params.clip
        resampled = tf.random.uniform([BATCH_SIZE, Z_DIM],
                                      minval=-params.clip,
                                      maxval=params.clip)
        new_enc = tf.where(out_of_range, resampled, latent)
    clip_latent = tf.assign(latent, new_enc)

# Monitor relu's activation.
if params.log_activation_layer:
    gen_scope = 'module_apply_' + gen_signature + '/'
    monitored_layer = tf.get_default_graph().get_tensor_by_name(
        gen_scope + params.log_activation_layer)
    # Fraction of entries of the monitored tensor that are non-zero.
    activation_rate = 1.0 - tf.nn.zero_fraction(monitored_layer)

# --------------------------
# Reconstruction losses.
# --------------------------
# Mse loss for image comparison.
if params.mse:
    # The pixel difference is halved before squaring.
    pix_square_diff = tf.square((target_img - gen_img) / 2.0)
    mse_loss = tf.reduce_mean(pix_square_diff)
    img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1, 2, 3])
else:
    mse_loss = tf.constant(0.0)
    img_mse_err = tf.constant(0.0)

# Use custom features for image comparison.
if params.features:
    feature_extractor = hub.Module(str(params.feature_extractor_path))

    def _extract_features(images):
        """Run the feature extractor and select the configured output."""
        outputs = feature_extractor(dict(images=images), as_dict=True,
                                    signature='image_feature_vector')
        return outputs[params.feature_extractor_output]

    # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
    gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
    target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])

    # Convert images to appropriate size for feature extraction.
    height, width = hub.get_expected_image_size(feature_extractor)
    gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
    target_img_1 = tf.image.resize_images(target_img_1, [height, width])

    gen_feat = _extract_features(gen_img_1)
    target_feat = _extract_features(target_img_1)

    # Squared feature differences, flattened to one row per image.
    feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
                                  [BATCH_SIZE, -1])
    feat_loss = tf.reduce_mean(feat_square_diff)
    img_feat_err = tf.reduce_mean(feat_square_diff, axis=1)
else:
    feat_loss = tf.constant(0.0)
    img_feat_err = tf.constant(0.0)

# --------------------------
# Regularization losses.
# --------------------------
# Loss on the norm of the encoding.
if params.norm_loss:
    # The latent is processed in consecutive sections of `dim` entries; the
    # squared l2 norm of each section is scored under a Chi2(dim)
    # distribution, penalizing only values above the mode (dim - 2).
    # NOTE(review): a trailing remainder of Z_DIM that does not fill a whole
    # section is not covered by this loss.
    dim = 20
    chi2_dist = tfp.distributions.Chi2(dim)
    mode = dim - 2
    mode_log_prob = chi2_dist.log_prob(mode)
    norm_loss = 0.0
    for section_start in range(0, (Z_DIM // dim) * dim, dim):
        section = latent[:, section_start:section_start + dim]
        squared_l2 = tf.reduce_sum(tf.square(section), axis=1)
        over_mode = tf.nn.relu(squared_l2 - mode)
        norm_loss -= tf.reduce_mean(chi2_dist.log_prob(mode + over_mode))
        # NOTE(review): the flattened source does not show whether this offset
        # was added per section or once after the loop; it is applied per
        # section here so each section contributes zero at the mode.
        norm_loss += mode_log_prob
else:
    norm_loss = tf.constant(0.0)

# Loss on the likelihood of the encoding.
if params.likeli_loss:
    norm_dist = tfp.distributions.Normal(0.0, 1.0)
    likeli_loss = - tf.reduce_mean(norm_dist.log_prob(latent))
    # Offset by the log-probability at 0 so the loss is zero for a zero latent.
    mode_log_prob = norm_dist.log_prob(0.0)
    likeli_loss += mode_log_prob
else:
    likeli_loss = tf.constant(0.0)

# Regularization loss.
reg_loss = norm_loss + likeli_loss

# Loss on the l1 distance between gen_encoding and inverted encoding.
if params.dist_loss:
    dist_loss = tf.reduce_mean(tf.abs(encoding - gen_encoding))
else:
    dist_loss = tf.constant(0.0)

# Per image reconstruction error.
img_rec_err = (params.lambda_mse * img_mse_err
               + params.lambda_feat * img_feat_err)

# Batch reconstruction error.
rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss

# Total inversion loss.
inv_loss = (rec_loss + params.lambda_reg * reg_loss
            + params.lambda_dist * dist_loss)

# --------------------------
# Optimizer.
# --------------------------
if params.decay_lr:
    # Staircase decay by a factor of 0.1 every inv_it / decay_n steps.
    lrate = tf.train.exponential_decay(params.lr, inv_step,
                                       params.inv_it / params.decay_n, 0.1,
                                       staircase=True)
else:
    lrate = tf.constant(params.lr)
# NOTE(review): when inv_layer == 'latent', `encoding` is the very same
# variable as `latent`, so the list below contains it twice (and `fixed_z`
# cannot freeze it). Confirm how the optimizer treats the duplicate entry
# before relying on the effective step size in that mode.
trained_params = [encoding] if params.fixed_z else [latent, encoding]
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9,
                                   beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
                                  global_step=inv_step)
reinit_optimizer = tf.variables_initializer(optimizer.variables())

# --------------------------
# Noise source.
# --------------------------


def noise_sampler():
    """Return a [BATCH_SIZE, Z_DIM] array of standard normal samples."""
    return np.random.normal(size=[BATCH_SIZE, Z_DIM])


def small_init(shape=None):
    """Return uniform samples in [params.init_lo, params.init_hi).

    Args:
      shape: shape of the returned array; defaults to [BATCH_SIZE, Z_DIM].
    """
    if shape is None:
        shape = [BATCH_SIZE, Z_DIM]
    return np.random.uniform(low=params.init_lo, high=params.init_hi,
                             size=shape)


# --------------------------
# Dataset.
# --------------------------
if params.dataset.endswith('.hdf5'):
    in_file = h5py.File(params.dataset, 'r')
    sample_images = in_file['xtrain'][()]
    if COND_GAN:
        sample_labels = in_file['ytrain'][()]
    sample_fns = in_file['fn'][()]
    NUM_IMGS = sample_images.shape[0]  # number of images to be inverted.
    print("Number of images: {}".format(NUM_IMGS))
    print("Batch size: {}".format(BATCH_SIZE))

    # The generators below read the module-level arrays and NUM_IMGS lazily
    # (on first iteration), so they see the padded data prepared further down.
    def sample_images_gen():
        """Yield (images, labels) batches; labels are zeros if unconditional."""
        for i in range(NUM_IMGS // BATCH_SIZE):
            i_1, i_2 = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
            if COND_GAN:
                yield sample_images[i_1:i_2], sample_labels[i_1:i_2]
            else:
                yield sample_images[i_1:i_2], np.zeros(BATCH_SIZE)

    image_gen = sample_images_gen()

    if 'latent' in in_file:
        sample_latents = in_file['latent']

        def sample_latent_gen():
            """Yield batches of pre-trained latents."""
            for i in range(NUM_IMGS // BATCH_SIZE):
                i_1, i_2 = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
                yield sample_latents[i_1:i_2]

        latent_gen = sample_latent_gen()

    # Pad the dataset to a whole number of batches by repeating trailing items.
    if NUM_IMGS % BATCH_SIZE != 0:
        REMAINDER = BATCH_SIZE - (NUM_IMGS % BATCH_SIZE)
        # Indices of the items to replicate. The modulo wraps around when the
        # dataset holds fewer items than the padding needs.
        pad_idx = np.arange(NUM_IMGS - REMAINDER, NUM_IMGS) % NUM_IMGS
        # The arrays must be concatenated: `+=` on a NumPy array is an
        # in-place elementwise addition, not an append.
        sample_images = np.concatenate([sample_images,
                                        sample_images[pad_idx]])
        if COND_GAN:
            sample_labels = np.concatenate([sample_labels,
                                            sample_labels[pad_idx]])
        sample_fns = np.concatenate([sample_fns, sample_fns[pad_idx]])
        if 'latent' in in_file:
            # Load the latents into memory so they can be padded the same way.
            # Assumes one latent per image -- TODO confirm.
            _latents = sample_latents[()]
            sample_latents = np.concatenate([_latents, _latents[pad_idx]])
        NUM_IMGS += REMAINDER
    assert NUM_IMGS % BATCH_SIZE == 0
else:
    sys.exit('Unknown dataset {}.'.format(params.dataset))

# --------------------------
# Training.
# --------------------------
# Start session.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())

# Output file.
out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w')
out_images = out_file.create_dataset('xtrain', [NUM_IMGS] + IMG_SHAPE,
                                     dtype='float32')
out_enc = out_file.create_dataset('encoding', [NUM_IMGS] + ENC_SHAPE,
                                  dtype='float32')
out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM],
                                  dtype='float32')
out_fns = out_file.create_dataset('fn', [NUM_IMGS],
                                  dtype=h5py.string_dtype())
if COND_GAN:
    out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,),
                                         dtype='float32')
out_err = out_file.create_dataset('err', (NUM_IMGS,))
out_fns[:] = sample_fns

# Gradient descent w.r.t. generator's inputs.
# NOTE(review): the nesting of the logging section below was reconstructed
# from the section comments of a source whose line breaks were lost; confirm
# against the original layout.
# NOTE(review): every `x.assign(value)` below adds a new op to the graph each
# time it is built, so the graph grows with the number of batches.
it = 0
out_pos = 0
start_time = time.time()

for image_batch, label_batch in image_gen:

    # Save target.
    sess.run(target.assign(image_batch))
    if COND_GAN:
        sess.run(label.assign(label_batch))

    # Initialize encodings to random values.
    if params.pre_trained_latent:
        sess.run(latent.assign(next(latent_gen)))
        if params.inv_layer != 'latent':
            # Start the encoding from the generator's own activation.
            sess.run(encoding.assign(gen_encoding))
    else:
        if params.init_gen_dist:
            sess.run(latent.assign(noise_sampler()))
            if params.inv_layer != 'latent':
                sess.run(encoding.assign(gen_encoding))
        else:
            sess.run(latent.assign(small_init()))
            if params.inv_layer != 'latent':
                sess.run(encoding.assign(
                    small_init(shape=[BATCH_SIZE] + ENC_SHAPE)))

    # Init optimizer.
    sess.run(inv_step.assign(0))
    sess.run(reinit_optimizer)

    # Main optimization loop.
    print("Total iterations: {}".format(params.inv_it))
    for _ in range(params.inv_it):

        (_inv_loss, _mse_loss, _feat_loss, _rec_loss, _reg_loss, _dist_loss,
         _lrate, _) = sess.run([inv_loss, mse_loss, feat_loss, rec_loss,
                                reg_loss, dist_loss, lrate, inv_train_op])

        if params.clipping or params.stochastic_clipping:
            sess.run(clip_latent)

        # Every 100 iterations save logs with training information.
        # `it` counts iterations across all batches, so the first 100
        # iterations of the whole run are logged individually.
        if it < 100 or it % 100 == 0:
            # Log losses.
            etime = time.time() - start_time
            print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
                  'feat [{:.4f}] rec [{:.4f}] reg [{:.4f}] dist [{:.4f}] '
                  'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
                                       _feat_loss, _rec_loss, _reg_loss,
                                       _dist_loss, _lrate))

            if params.log_z_norm:
                _lat = sess.run(latent)
                # Sections of 20 entries for a 120-dimensional latent;
                # otherwise the latent is reported as a single section.
                dim = 20 if Z_DIM == 120 else Z_DIM
                for i in range(Z_DIM // dim):
                    _subset = _lat[:, i * dim:(i + 1) * dim]
                    print('section {:1d}: norm={:.4f} (exp={:.4f}) min={:.4f} max={:.4f}'
                          .format(i,
                                  np.mean(np.linalg.norm(_subset, axis=1)),
                                  np.sqrt(dim - 2),
                                  np.min(_subset),
                                  np.max(_subset)))

            if params.log_activation_layer:
                _act_rate = sess.run(activation_rate)
                print('activation_rate={:.4f}'.format(_act_rate))
                log_stats('activation rate', _act_rate, it)

            sys.stdout.flush()

            # Log tensorboard's statistics.
            log_stats('total loss', _inv_loss, it)
            log_stats('mse loss', _mse_loss, it)
            log_stats('feat loss', _feat_loss, it)
            log_stats('rec loss', _rec_loss, it)
            log_stats('reg loss', _reg_loss, it)
            log_stats('dist loss', _dist_loss, it)
            log_stats('out pos', out_pos, it)
            log_stats('lrate', _lrate, it)
            summary_writer.flush()

            # Save target images and reconstructions.
            if params.save_progress:
                assert SAMPLE_SIZE <= BATCH_SIZE
                gen_images = sess.run(gen_img)
                # Targets and reconstructions of the last SAMPLE_SIZE items.
                inv_batch = vs.interleave(
                    vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
                    vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
                inv_batch = vs.grid_transform(inv_batch)
                vs.save_image('{}/progress_{}_{}.png'.format(
                    SAMPLES_DIR, params.tag, it), inv_batch)

            # Save linear interpolation between the actual and generated
            # encodings.
            if params.dist_loss and it % 1000 == 0:
                enc_batch, gen_enc = sess.run([encoding, gen_encoding])
                for j in range(10):
                    custom_enc = (gen_enc * (1 - (j / 10.0))
                                  + enc_batch * (j / 10.0))
                    sess.run(encoding.assign(custom_enc))
                    gen_images = sess.run(gen_img)
                    inv_batch = vs.interleave(
                        vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
                        vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
                    inv_batch = vs.grid_transform(inv_batch)
                    vs.save_image('{}/progress_{}_{}_lat_{}.png'.format(
                        SAMPLES_DIR, params.tag, it, j), inv_batch)
                # Restore the encoding that was being optimized.
                sess.run(encoding.assign(enc_batch))

        # It counter.
        it += 1

    # Save samples of inverted images.
    if SAMPLE_SIZE > 0:
        assert SAMPLE_SIZE <= BATCH_SIZE
        gen_images = sess.run(gen_img)
        inv_batch = vs.interleave(
            vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
            vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
        inv_batch = vs.grid_transform(inv_batch)
        vs.save_image('{}/{}_{}.png'.format(SAMPLES_DIR, params.tag, out_pos),
                      inv_batch)
        print('Saved samples for out_pos: {}.'.format(out_pos))

    # Save images that are ready.
    latent_batch, enc_batch, rec_err_batch = sess.run(
        [latent, encoding, img_rec_err])
    out_lat[out_pos:out_pos + BATCH_SIZE] = latent_batch
    out_enc[out_pos:out_pos + BATCH_SIZE] = enc_batch
    out_images[out_pos:out_pos + BATCH_SIZE] = image_batch
    if COND_GAN:
        out_labels[out_pos:out_pos + BATCH_SIZE] = label_batch
    out_err[out_pos:out_pos + BATCH_SIZE] = rec_err_batch
    out_pos += BATCH_SIZE

print('Mean reconstruction error: {}'.format(np.mean(out_err)))
print('Stdev reconstruction error: {}'.format(np.std(out_err)))
print('End of inversion.')
out_file.close()
sess.close()