path: root/cli/app/search/search_dense.py
Diffstat (limited to 'cli/app/search/search_dense.py')
-rw-r--r--  cli/app/search/search_dense.py  74
1 file changed, 59 insertions(+), 15 deletions(-)
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index 04c99ab..992eb93 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -81,11 +81,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
if COND_GAN:
Z_DIM = input_info['z'].get_shape().as_list()[1]
- latent = tf.get_variable(name='latent', dtype=tf.float32,
- shape=[BATCH_SIZE, Z_DIM])
+ latent = tf.get_variable(name='latent', dtype=tf.float32, shape=[BATCH_SIZE, Z_DIM])
N_CLASS = input_info['y'].get_shape().as_list()[1]
- label = tf.get_variable(name='label', dtype=tf.float32,
- shape=[BATCH_SIZE, N_CLASS])
+ label = tf.get_variable(name='label', dtype=tf.float32, shape=[BATCH_SIZE, N_CLASS])
gen_in = dict(params.generator_fixed_inputs)
gen_in['z'] = latent
gen_in['y'] = label
@@ -114,10 +112,31 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
- encoding = tf.get_variable(name='encoding', dtype=tf.float32,
- shape=[BATCH_SIZE,] + ENC_SHAPE)
+ encoding = tf.get_variable(name='encoding', dtype=tf.float32, shape=[BATCH_SIZE,] + ENC_SHAPE)
tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+ layer_label_variables = []
+ gen_label = tf.get_default_graph().get_tensor_by_name('module_apply_{}/linear_1/MatMul:0'.format(gen_signature))
+ if params.invert_labels:
+ op_names = [
+ "Generator_2/concat",
+ "Generator_2/concat_1",
+ "Generator_2/concat_2",
+ "Generator_2/concat_3",
+ "Generator_2/concat_4",
+ "Generator_2/concat_5",
+ "Generator_2/concat_6",
+ ]
+ op_input_index = 1
+ layer_shape = [1, 128,]
+ for op_name in op_names:
+ layer_name = 'module_apply_{}/{}'.format(gen_signature, op_name)
+ variable_name = op_name + "_label"
+ raw_op = tf.get_default_graph().get_operation_by_name(layer_name)
+ new_op_input = tf.get_variable(name=variable_name, dtype=tf.float32, shape=[BATCH_SIZE,] + layer_shape)
+ raw_op._update_input(op_input_index, new_op_input)
+ layer_label_variables.append(new_op_input)
+
# Step counter.
inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)
@@ -130,8 +149,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# Custom Gradient for Relus.
if params.custom_grad_relu:
- grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5,
- 0.1, staircase=False)
+ grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5, 0.1, staircase=False)
@tf.custom_gradient
def relu_custom_grad(x):
def grad(dy):
@@ -219,7 +237,14 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
lrate = tf.constant(params.lr)
# trained_params = [label, latent, encoding]
- trained_params = [latent, encoding]
+ # trained_params = [latent, encoding]
+ if params.inv_layer == 'latent':
+ trained_params = [latent]
+ else:
+ trained_params = [latent, encoding]
+
+ if params.invert_labels:
+ trained_params += layer_label_variables
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
@@ -310,11 +335,22 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
latent.assign(next(latent_gen)),
inv_step.assign(0),
])
- sess.run([
- encoding.assign(gen_encoding),
+
+ encoding_init_funcs = [
reinit_optimizer,
reinit_optimizer_quad,
- ])
+ ]
+
+ if params.inv_layer != 'latent':
+ encoding_init_funcs += [
+ encoding.assign(gen_encoding),
+ ]
+
+ if params.invert_labels:
+ for layer_label in layer_label_variables:
+ encoding_init_funcs.append(layer_label.assign(gen_label))
+
+ sess.run(encoding_init_funcs)
# Main optimization loop.
print("Beginning dense iteration...")
@@ -356,7 +392,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
it += 1
# Save images that are ready.
- label_trained, latent_trained, enc_trained, rec_err_trained = sess.run([label, latent, encoding, img_rec_err])
+ label_trained, latent_trained = sess.run([label, latent])
+ if params.inv_layer != 'latent':
+ enc_trained = sess.run(encoding)
+ if params.invert_labels:
+ layer_labels_trained = sess.run(layer_label_variables)
gen_images = sess.run(gen_img_orig)
images = vs.data2img(gen_images)
@@ -382,14 +422,18 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
'sample_fn': sample_fn,
'label': label_trained[i],
'latent': latent_trained[i],
- 'encoding': enc_trained[i],
}
+ if params.inv_layer != 'latent':
+ out_data['encoding'] = enc_trained[i]
+ if params.invert_labels:
+ out_data['layer_labels'] = []
+ for layer in layer_labels_trained:
+ out_data['layer_labels'].append(layer[i])
write_pickle(out_data, fp_out_pkl)
out_lat[out_i] = latent_trained[i]
-out_enc[out_i] = enc_trained[i]
+if params.inv_layer != 'latent':
+    out_enc[out_i] = enc_trained[i]
out_images[out_i] = image_batch[i]
out_labels[out_i] = label_trained[i]
- out_err[out_i] = rec_err_trained[i]
out_pos += BATCH_SIZE
if params.max_batches > 0 and (out_pos / BATCH_SIZE) >= params.max_batches:
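
Note on the graph surgery used in the `invert_labels` branch above: the patch makes a per-layer label input trainable by allocating a fresh variable and splicing it into an existing op via the private `_update_input` call, much like `tf.contrib.graph_editor.swap_ts` is used for the encoding tensor. The following is a minimal, standalone TF1-style sketch of that idea, not code from this repository; the names `replace_op_input`, `toy_add`, and `w_var` are illustrative only.

import tensorflow as tf  # assumes TensorFlow 1.x (or tf.compat.v1)

def replace_op_input(graph, op_name, input_index, var_name, shape):
    """Rewire one input of an existing op to a new trainable variable."""
    op = graph.get_operation_by_name(op_name)
    new_input = tf.get_variable(name=var_name, dtype=tf.float32, shape=shape)
    # _update_input is a private TF1 API: it swaps the op's input tensor in
    # place, so an optimizer can then train `new_input` directly.
    op._update_input(input_index, tf.convert_to_tensor(new_input))
    return new_input

with tf.Graph().as_default() as g:
    x = tf.placeholder(tf.float32, [1, 4], name="x")
    w = tf.constant(1.0, shape=[1, 4], name="w")
    y = tf.add(x, w, name="toy_add")  # op whose second input we hijack
    w_var = replace_op_input(g, "toy_add", 1, "w_var", [1, 4])
    # w_var can now be handed to an optimizer's var_list, analogous to the
    # layer_label_variables collected in the diff.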