# ------------------------------------------------------------------------------
# Cluster the attention map of inverted images for unsupervised segmentation.
# ------------------------------------------------------------------------------

import h5py
import numpy as np
import os
import params
import scipy
import scipy.cluster.hierarchy
from sklearn.cluster import AgglomerativeClustering
import sys
import tensorflow as tf
import tensorflow_hub as hub
import time
import visualize as vs

if len(sys.argv) < 2:
    sys.exit('Must provide a configuration file.')
params = params.Params(sys.argv[1])
params.batch_size = 1
params.sample_size = 1

# --------------------------
# Global directories.
# --------------------------
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size
SAMPLES_DIR = 'attention'
INVERSES_DIR = 'inverses'
if not os.path.exists(SAMPLES_DIR):
    os.makedirs(SAMPLES_DIR)


# --------------------------
# Util functions.
# --------------------------
def one_hot(values):
    """Return the one-hot encoding of integer class indices `values`.

    Uses the global N_CLASS, which is only defined (in the graph-loading
    section) when the generator is conditional.
    """
    return np.eye(N_CLASS)[values]


def segment_img(diss_matrix, n_clusters, image=None):
    """Segment an image by clustering a precomputed dissimilarity matrix.

    Args:
      diss_matrix: (N, N) pairwise dissimilarities between N = grid * grid
        positions of a square grid (64x64 in this script).
      n_clusters: number of segments to produce.
      image: channels-first image batch of shape (1, C, H, W) whose colours
        are averaged per segment. Defaults to the module-level `_gen_img`
        set by the main loop (the original, implicit behaviour).

    Returns:
      `vs.data2img` applied to a (1, C, H, W) array in which every pixel holds
      the mean colour of the segment it belongs to.
    """
    if image is None:
        image = _gen_img  # Backward compatibility: read the loop global.
    image = np.asarray(image)
    _, channels, height, width = image.shape

    # Cluster grid positions using the attention-derived dissimilarities.
    clustering = AgglomerativeClustering(n_clusters=n_clusters,
                                         affinity='precomputed',
                                         linkage='average')
    clustering.fit(diss_matrix)
    labels = clustering.labels_

    # Upsample the label grid to image resolution by repeating each label
    # scale x scale times (64x64 -> 128x128 for the current generator).
    grid = int(round(np.sqrt(labels.shape[0])))
    scale = height // grid
    labels = np.broadcast_to(labels.reshape(grid, 1, grid, 1),
                             (grid, scale, grid, scale)).reshape(height * width)

    # Mean colour per cluster, then paint every pixel with its cluster mean.
    one_hot_labels = np.eye(labels.max() + 1)[labels]              # (H*W, K)
    pixels = np.transpose(image, [0, 2, 3, 1]).reshape(height * width,
                                                       channels)   # (H*W, C)
    cluster_sum = np.matmul(one_hot_labels.T, pixels)              # (K, C)
    cluster_count = one_hot_labels.sum(axis=0).reshape(-1, 1)      # (K, 1)
    cluster_mean = cluster_sum / cluster_count
    labels_img = cluster_mean[labels]                              # (H*W, C)
    labels_img = np.transpose(labels_img, [1, 0]).reshape(1, channels,
                                                          height, width)
    return vs.data2img(labels_img)

# --------------------------
# Load Graph.
# --------------------------
generator = hub.Module(str(params.generator_path))

gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
    gen_signature = 'default'

input_info = generator.get_input_info_dict(gen_signature)
# A conditional generator exposes a class-label input 'y'.
COND_GAN = 'y' in input_info

if COND_GAN:
    Z_DIM = input_info['z'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    N_CLASS = input_info['y'].get_shape().as_list()[1]
    label = tf.get_variable(name='label', dtype=tf.float32,
                            shape=[BATCH_SIZE, N_CLASS])
    # `or {}` tolerates an empty/None config value, as the branch below does.
    gen_in = dict(params.generator_fixed_inputs or {})
    gen_in['z'] = latent
    gen_in['y'] = label
    gen_img = generator(gen_in, signature=gen_signature)
else:
    Z_DIM = input_info['default'].get_shape().as_list()[1]
    latent = tf.get_variable(name='latent', dtype=tf.float32,
                             shape=[BATCH_SIZE, Z_DIM])
    if params.generator_fixed_inputs:
        gen_in = dict(params.generator_fixed_inputs)
        gen_in['z'] = latent
        gen_img = generator(gen_in, signature=gen_signature)
    else:
        gen_img = generator(latent, signature=gen_signature)

# Convert generated image to channels_first.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Override intermediate layer: replace the chosen generator tensor with a
# trainable variable so stored encodings can be fed back in.
if params.inv_layer == 'latent':
    encoding = latent
    ENC_SHAPE = [Z_DIM]
else:
    layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
    gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
    ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
    encoding = tf.get_variable(name='encoding', dtype=tf.float32,
                               shape=[BATCH_SIZE, ] + ENC_SHAPE)
    tf.contrib.graph_editor.swap_ts(gen_encoding,
                                    tf.convert_to_tensor(encoding))

# Get attention map.
att_map_name = ('module_apply_' + gen_signature + '/'
                + params.attention_map_layer)
att_map = tf.get_default_graph().get_tensor_by_name(att_map_name)

# Define image shape.
IMG_SHAPE = gen_img.get_shape().as_list()[1:]

# --------------------------
# Dataset.
# --------------------------
if params.out_dataset.endswith('.hdf5'):
    in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
    sample_images = in_file['xtrain']
    if COND_GAN:
        sample_labels = in_file['ytrain']
    sample_latents = in_file['latent']
    sample_encodings = in_file['encoding']
    NUM_IMGS = sample_images.shape[0]  # number of images.

    def sample_images_gen():
        """Yield (images, labels, latents, encodings) batches of BATCH_SIZE.

        A trailing partial batch is dropped. For unconditional GANs the
        label batch is a dummy array of zeros.
        """
        # Floor division keeps this an int on both Python 2 and Python 3.
        for i in range(NUM_IMGS // BATCH_SIZE):
            i_1, i_2 = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
            if COND_GAN:
                label_batch = sample_labels[i_1:i_2]
            else:
                label_batch = np.zeros(BATCH_SIZE)
            yield (sample_images[i_1:i_2], label_batch,
                   sample_latents[i_1:i_2], sample_encodings[i_1:i_2])

    image_gen = sample_images_gen()
else:
    sys.exit('Unknown dataset {}.'.format(params.out_dataset))

NUM_IMGS -= NUM_IMGS % BATCH_SIZE

# --------------------------
# Training.
# --------------------------
# Export attention map for reconstructed images. The HDF5 file is read lazily
# by the generator, so it is closed only after the loop finishes (or fails).
try:
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        it = 0
        for image_batch, label_batch, lat_batch, enc_batch in image_gen:
            # Set target label.
            if COND_GAN:
                sess.run(label.assign(one_hot(label_batch)))

            # Initialize encodings.
            sess.run(latent.assign(lat_batch))
            sess.run(encoding.assign(enc_batch))

            # Get attention map.
            _att_map, _gen_img = sess.run([att_map, gen_img])

            # Upsampling (from 32x32 to 64x64) along the second axis.
            # NOTE(review): assumes the map is 64*64 query x 32*32 key
            # positions — TODO confirm for other generators.
            _att_map = np.broadcast_to(
                _att_map.reshape(64, 64, 32, 1, 32, 1),
                (64, 64, 32, 2, 32, 2)).reshape(4096, 4096)

            # Define dissimilarity matrix: symmetrize the attention, turn it
            # into a distance, and zero the self-distances on the diagonal.
            dissimilarity = 1.0 - (_att_map + _att_map.T) / 2.0
            np.fill_diagonal(dissimilarity, 0.0)

            # Segment the image with different number of clusters.
            seg_img_8 = segment_img(dissimilarity, 8)
            seg_img_20 = segment_img(dissimilarity, 20)
            seg_img_40 = segment_img(dissimilarity, 40)

            # Save segmentation.
            out_batch_1 = vs.interleave(image_batch, seg_img_20)
            out_batch_2 = vs.interleave(seg_img_8, seg_img_40)
            out_batch = vs.interleave(out_batch_1, out_batch_2)
            out_batch = vs.seq_transform(out_batch)
            vs.save_image('{}/segmented_img_{}.png'.format(SAMPLES_DIR, it),
                          out_batch)
            it += 1
finally:
    in_file.close()