summaryrefslogtreecommitdiff
path: root/cli
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-10 19:45:20 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-10 19:45:20 +0100
commit735b26c5b88ae1de1eac134f6d9a626ded33eeac (patch)
tree1caecff9a1bee5d2bf570a0f494998a22b91a5ba /cli
parentf462fe520aca4fd15127fd6d0b27e342e2f23a14 (diff)
load latents and labels, share interpolation amount
Diffstat (limited to 'cli')
-rw-r--r--cli/app/search/live.py73
-rw-r--r--cli/app/search/search_dense.py12
2 files changed, 58 insertions, 27 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
index 21a9c5a..183bab9 100644
--- a/cli/app/search/live.py
+++ b/cli/app/search/live.py
@@ -138,7 +138,7 @@ class LerpParam:
def update(self, dt):
if self.direction != 0:
self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt)
- print("set_opt: {} {}".format(self.name, self.n.value))
+ print("set_opt: {}_n {}".format(self.name, self.n.value))
if self.n.value == 0 or self.n.value == 1:
self.direction = 0
@@ -199,38 +199,58 @@ class Interpolator:
def build(self):
InterpolatorParam(name='truncation', value=1.0)
InterpolatorParam(name='num_classes', value=1.0)
- abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0)
+
+ # Latent - initial lerp and wobble
lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
- lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
-
z_sum = lerp_z.output + sin_z.output
+
+ # Latent - saturation
+ abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0)
z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable
z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input")
- # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise')
- # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label')
+ # Latent - stored vector
+ latent_stored = LerpParam(name='latent_stored', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+ latent_stored_mix = LerpParam('latent_stored_mix', a_in=z_mix.output, b_in=latent_stored.variable, shape=[BATCH_SIZE, Z_DIM], datatype="input")
+ # Label
+ lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
+
+    # Label - stored vector
+ label_stored = LerpParam(name='label_stored', shape=[BATCH_SIZE, N_CLASS], datatype="label")
+    label_stored_mix = LerpParam('label_stored_mix', a_in=lerp_label.output, b_in=label_stored.variable, shape=[BATCH_SIZE, N_CLASS], datatype="label")
+
+ # Generator
gen_in = {}
gen_in['truncation'] = 1.0 # self.opts['truncation'].variable
- gen_in['z'] = z_mix.output
- gen_in['y'] = lerp_label.output
+ gen_in['z'] = latent_stored_mix.output
+ gen_in['y'] = label_stored_mix.output
self.gen_img = generator(gen_in, signature=gen_signature)
+ # Encoding - first hidden layer
gen_layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
encoding_latent = tf.get_default_graph().get_tensor_by_name(gen_layer_name)
encoding_shape = encoding_latent.get_shape().as_list()
- encoding_shape_specific = tuple([1,] + encoding_shape[1:])
+ encoding_shape_np = tuple([1,] + encoding_shape[1:])
+
+ encoding_shape_placeholder = tf.constant(np.zeros(encoding_shape_np, dtype=np.float32))
+ encoding_stored = LerpParam('encoding_stored', shape=encoding_shape_np, datatype="encoding")
+ encoding_stored_mix = LerpParam('encoding_stored_mix', a_in=encoding_shape_placeholder, b_in=encoding_stored.output, shape=encoding_shape_np, datatype="encoding")
- encoding_shape_placeholder = tf.constant(np.zeros(encoding_shape_specific, dtype=np.float32))
- encoding_stored = LerpParam('encoding_stored', shape=encoding_shape_specific, datatype="encoding")
- encoding_mix = LerpParam('encoding_mix', a_in=encoding_shape_placeholder, b_in=encoding_stored.output, shape=encoding_shape_specific, datatype="encoding")
- # use the placeholder to redirect parts of the graph.
+ # Use the placeholder to redirect parts of the graph.
# - computed encoding goes into the encoding_mix
# - encoding mix output goes into the main biggan graph
- tf.contrib.graph_editor.swap_ts(encoding_latent, encoding_shape_placeholder)
- tf.contrib.graph_editor.swap_ts(encoding_shape_placeholder, encoding_mix.output)
+ # We do it this way so the encoding_latent won't be going into two places at once.
+ tf.contrib.graph_editor.swap_ts(encoding_shape_placeholder, encoding_latent)
+ tf.contrib.graph_editor.swap_ts(encoding_stored_mix.output, encoding_shape_placeholder)
+
+ # Make all the stored lerps use the same interpolation amount.
+ tf.contrib.graph_editor.reroute_ts(encoding_stored.n.variable, latent_stored.n.variable)
+ tf.contrib.graph_editor.reroute_ts(encoding_stored.n.variable, label_stored.n.variable)
+ tf.contrib.graph_editor.reroute_ts(encoding_stored_mix.n.variable, latent_stored_mix.n.variable)
+ tf.contrib.graph_editor.reroute_ts(encoding_stored_mix.n.variable, label_stored_mix.n.variable)
sys.stderr.write("Sin params: {}\n".format(", ".join(self.sin_params.keys())))
sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys())))
@@ -273,26 +293,37 @@ class Interpolator:
def set_encoding(self, opt):
next_id = opt['id']
data = load_pickle(os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(next_id)))
+
new_encoding = np.expand_dims(data['encoding'], axis=0)
new_label = np.expand_dims(data['label'], axis=0)
+ new_latent = np.expand_dims(data['latent'], axis=0)
+
+ latent_stored = self.lerp_params['latent_stored']
+ latent_stored_mix = self.lerp_params['latent_stored_mix']
+ label_stored = self.lerp_params['label_stored']
+ label_stored_mix = self.lerp_params['label_stored_mix']
encoding_stored = self.lerp_params['encoding_stored']
- encoding_mix = self.lerp_params['encoding_mix']
- label = self.lerp_params['label']
+ encoding_stored_mix = self.lerp_params['encoding_stored_mix']
+
# if we're showing an encoding already, lerp to the next one
-    if encoding_mix.n.value > 0:
+    if encoding_stored_mix.n.value > 0:
encoding_stored.switch(target_value=new_encoding)
- label.switch(target_value=new_label)
+ label_stored.switch(target_value=new_label)
+ latent_stored.switch(target_value=new_latent)
# otherwise (we're showing the latent)...
else:
# jump to the stored encoding, then switch
if encoding_stored.n.value < 0.5:
- encoding_stored.n.value = 0
+ encoding_stored.n.assign(0)
encoding_stored.a.assign(new_encoding)
+ latent_stored.a.assign(new_latent)
+ label_stored.a.assign(new_label)
else:
- encoding_stored.n.value = 1
+ encoding_stored.n.assign(1)
encoding_stored.b.assign(new_encoding)
+ latent_stored.b.assign(new_latent)
+ label_stored.b.assign(new_label)
-      encoding_mix.switch()
+      encoding_stored_mix.switch()
- label.switch(target_value=new_label)
def on_step(self, i, dt, sess):
for param in self.sin_params.values():
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index c7cf078..616ba6a 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -280,7 +280,7 @@ def find_dense_embedding_for_images(params):
params.inv_it / params.decay_n, 0.1, staircase=True)
else:
lrate = tf.constant(params.lr)
- trained_params = [encoding] if params.fixed_z else [latent, encoding]
+ trained_params = [latent, encoding]
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
global_step=inv_step)
@@ -406,12 +406,12 @@ def find_dense_embedding_for_images(params):
it += 1
# Save images that are ready.
- latent_batch, enc_batch, rec_err_batch = sess.run([latent, encoding, img_rec_err])
- out_lat[out_pos:out_pos+BATCH_SIZE] = latent_batch
- out_enc[out_pos:out_pos+BATCH_SIZE] = enc_batch
+ label_trained, latent_trained, enc_trained, rec_err_trained = sess.run([label, latent, encoding, img_rec_err])
+ out_lat[out_pos:out_pos+BATCH_SIZE] = latent_trained
+ out_enc[out_pos:out_pos+BATCH_SIZE] = enc_trained
out_images[out_pos:out_pos+BATCH_SIZE] = image_batch
- out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
- out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
+ out_labels[out_pos:out_pos+BATCH_SIZE] = label_trained
+ out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_trained
gen_images = sess.run(gen_img_orig)
images = vs.data2img(gen_images)