diff options
Diffstat (limited to 'cli/app/search/search_dense.py')
| -rw-r--r-- | cli/app/search/search_dense.py | 74 |
1 files changed, 59 insertions, 15 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py index 04c99ab..992eb93 100644 --- a/cli/app/search/search_dense.py +++ b/cli/app/search/search_dense.py @@ -81,11 +81,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op if COND_GAN: Z_DIM = input_info['z'].get_shape().as_list()[1] - latent = tf.get_variable(name='latent', dtype=tf.float32, - shape=[BATCH_SIZE, Z_DIM]) + latent = tf.get_variable(name='latent', dtype=tf.float32, shape=[BATCH_SIZE, Z_DIM]) N_CLASS = input_info['y'].get_shape().as_list()[1] - label = tf.get_variable(name='label', dtype=tf.float32, - shape=[BATCH_SIZE, N_CLASS]) + label = tf.get_variable(name='label', dtype=tf.float32, shape=[BATCH_SIZE, N_CLASS]) gen_in = dict(params.generator_fixed_inputs) gen_in['z'] = latent gen_in['y'] = label @@ -114,10 +112,31 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name) ENC_SHAPE = gen_encoding.get_shape().as_list()[1:] - encoding = tf.get_variable(name='encoding', dtype=tf.float32, - shape=[BATCH_SIZE,] + ENC_SHAPE) + encoding = tf.get_variable(name='encoding', dtype=tf.float32, shape=[BATCH_SIZE,] + ENC_SHAPE) tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding)) + layer_label_variables = [] + gen_label = tf.get_default_graph().get_tensor_by_name('module_apply_{}/linear_1/MatMul:0'.format(gen_signature)) + if params.invert_labels: + op_names = [ + "Generator_2/concat", + "Generator_2/concat_1", + "Generator_2/concat_2", + "Generator_2/concat_3", + "Generator_2/concat_4", + "Generator_2/concat_5", + "Generator_2/concat_6", + ] + op_input_index = 1 + layer_shape = [1, 128,] + for op_name in op_names: + layer_name = 'module_apply_{}/{}'.format(gen_signature, op_name) + variable_name = op_name + "_label" + raw_op = tf.get_default_graph().get_op_by_name(layer_name) + new_op_input = tf.get_variable(name=variable_name, dtype=tf.float32, shape=[BATCH_SIZE,] + layer_shape) + op._update_input(op_input_index, new_op_input) + layer_label_variables.append(new_op_input) + # Step counter. inv_step = tf.get_variable('inv_step', initializer=0, trainable=False) @@ -130,8 +149,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op # Custom Gradient for Relus. if params.custom_grad_relu: - grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5, - 0.1, staircase=False) + grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5, 0.1, staircase=False) @tf.custom_gradient def relu_custom_grad(x): def grad(dy): @@ -219,7 +237,14 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op lrate = tf.constant(params.lr) # trained_params = [label, latent, encoding] - trained_params = [latent, encoding] + # trained_params = [latent, encoding] + if params.inv_layer == 'latent': + trained_params = [latent] + else: + trained_params = [latent, encoding] + + if params.invert_labels: + trained_params += layer_label_variables optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999) inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params, @@ -310,11 +335,22 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op latent.assign(next(latent_gen)), inv_step.assign(0), ]) - sess.run([ - encoding.assign(gen_encoding), + + encoding_init_funcs = [ reinit_optimizer, reinit_optimizer_quad, - ]) + ] + + if params.inv_layer != 'latent': + encoding_init_funcs += [ + encoding.assign(gen_encoding), + ] + + if params.invert_labels: + for layer_label in layer_label_variables: + encoding_init_funcs.append(gen_label.assign(layer_label)) + + sess.run(encoding_init_funcs) # Main optimization loop. print("Beginning dense iteration...") @@ -356,7 +392,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op it += 1 # Save images that are ready. - label_trained, latent_trained, enc_trained, rec_err_trained = sess.run([label, latent, encoding, img_rec_err]) + label_trained, latent_trained = sess.run([label, latent]) + if params.inv_layer != 'latent': + enc_trained = sess.run([encoding]) + if params.invert_labels: + layer_labels_trained = sess.run(layer_label_variables) gen_images = sess.run(gen_img_orig) images = vs.data2img(gen_images) @@ -382,14 +422,18 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op 'sample_fn': sample_fn, 'label': label_trained[i], 'latent': latent_trained[i], - 'encoding': enc_trained[i], } + if params.inv_layer != 'latent': + out_data['encoding'] = enc_trained[i] + if params.invert_labels: + out_data['layer_labels'] = [] + for layer in layer_labels_trained: + out_data['layer_labels'].append(layer[i]) write_pickle(out_data, fp_out_pkl) out_lat[out_i] = latent_trained[i] out_enc[out_i] = enc_trained[i] out_images[out_i] = image_batch[i] out_labels[out_i] = label_trained[i] - out_err[out_i] = rec_err_trained[i] out_pos += BATCH_SIZE if params.max_batches > 0 and (out_pos / BATCH_SIZE) >= params.max_batches: |
