diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-10 19:45:20 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-10 19:45:20 +0100 |
| commit | 735b26c5b88ae1de1eac134f6d9a626ded33eeac (patch) | |
| tree | 1caecff9a1bee5d2bf570a0f494998a22b91a5ba /cli | |
| parent | f462fe520aca4fd15127fd6d0b27e342e2f23a14 (diff) | |
load latents and labels, share interpolation amount
Diffstat (limited to 'cli')
| -rw-r--r-- | cli/app/search/live.py | 73 | ||||
| -rw-r--r-- | cli/app/search/search_dense.py | 12 |
2 files changed, 58 insertions, 27 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py index 21a9c5a..183bab9 100644 --- a/cli/app/search/live.py +++ b/cli/app/search/live.py @@ -138,7 +138,7 @@ class LerpParam: def update(self, dt): if self.direction != 0: self.n.value = clamp(self.n.value + self.direction * self.speed.value * dt) - print("set_opt: {} {}".format(self.name, self.n.value)) + print("set_opt: {}_n {}".format(self.name, self.n.value)) if self.n.value == 0 or self.n.value == 1: self.direction = 0 @@ -199,38 +199,58 @@ class Interpolator: def build(self): InterpolatorParam(name='truncation', value=1.0) InterpolatorParam(name='num_classes', value=1.0) - abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0) + + # Latent - initial lerp and wobble lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise") sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise") - lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label") - z_sum = lerp_z.output + sin_z.output + + # Latent - saturation + abs_zoom = InterpolatorParam(name='abs_zoom', value=1.0) z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input") - # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise') - # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label') + # Latent - stored vector + latent_stored = LerpParam(name='latent_stored', shape=[BATCH_SIZE, Z_DIM], datatype="noise") + latent_stored_mix = LerpParam('latent_stored_mix', a_in=z_mix.output, b_in=latent_stored.variable, shape=[BATCH_SIZE, Z_DIM], datatype="input") + # Label + lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label") + + # Latent - stored vector + label_stored = LerpParam(name='label_stored', shape=[BATCH_SIZE, N_CLASS], datatype="label") + label_stored_mix = LerpParam('label_stored_mix', a_in=z_mix.output, b_in=latent_stored.variable, shape=[BATCH_SIZE, Z_DIM], datatype="input") + + # Generator gen_in = {} gen_in['truncation'] = 1.0 # self.opts['truncation'].variable - gen_in['z'] = z_mix.output - gen_in['y'] = lerp_label.output + gen_in['z'] = latent_stored_mix.output + gen_in['y'] = label_stored_mix.output self.gen_img = generator(gen_in, signature=gen_signature) + # Encoding - first hidden layer gen_layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer encoding_latent = tf.get_default_graph().get_tensor_by_name(gen_layer_name) encoding_shape = encoding_latent.get_shape().as_list() - encoding_shape_specific = tuple([1,] + encoding_shape[1:]) + encoding_shape_np = tuple([1,] + encoding_shape[1:]) + + encoding_shape_placeholder = tf.constant(np.zeros(encoding_shape_np, dtype=np.float32)) + encoding_stored = LerpParam('encoding_stored', shape=encoding_shape_np, datatype="encoding") + encoding_stored_mix = LerpParam('encoding_stored_mix', a_in=encoding_shape_placeholder, b_in=encoding_stored.output, shape=encoding_shape_np, datatype="encoding") - encoding_shape_placeholder = tf.constant(np.zeros(encoding_shape_specific, dtype=np.float32)) - encoding_stored = LerpParam('encoding_stored', shape=encoding_shape_specific, datatype="encoding") - encoding_mix = LerpParam('encoding_mix', a_in=encoding_shape_placeholder, b_in=encoding_stored.output, shape=encoding_shape_specific, datatype="encoding") - # use the placeholder to redirect parts of the graph. + # Use the placeholder to redirect parts of the graph. # - computed encoding goes into the encoding_mix # - encoding mix output goes into the main biggan graph - tf.contrib.graph_editor.swap_ts(encoding_latent, encoding_shape_placeholder) - tf.contrib.graph_editor.swap_ts(encoding_shape_placeholder, encoding_mix.output) + # We do it this way so the encoding_latent won't be going into two places at once. + tf.contrib.graph_editor.swap_ts(encoding_shape_placeholder, encoding_latent) + tf.contrib.graph_editor.swap_ts(encoding_stored_mix.output, encoding_shape_placeholder) + + # Make all the stored lerps use the same interpolation amount. + tf.contrib.graph_editor.reroute_ts(encoding_stored.n.variable, latent_stored.n.variable) + tf.contrib.graph_editor.reroute_ts(encoding_stored.n.variable, label_stored.n.variable) + tf.contrib.graph_editor.reroute_ts(encoding_stored_mix.n.variable, latent_stored_mix.n.variable) + tf.contrib.graph_editor.reroute_ts(encoding_stored_mix.n.variable, label_stored_mix.n.variable) sys.stderr.write("Sin params: {}\n".format(", ".join(self.sin_params.keys()))) sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys()))) @@ -273,26 +293,37 @@ class Interpolator: def set_encoding(self, opt): next_id = opt['id'] data = load_pickle(os.path.join(app_cfg.DIR_VECTORS, "file_{}.pkl".format(next_id))) + new_encoding = np.expand_dims(data['encoding'], axis=0) new_label = np.expand_dims(data['label'], axis=0) + new_latent = np.expand_dims(data['latent'], axis=0) + + latent_stored = self.lerp_params['latent_stored'] + latent_stored_mix = self.lerp_params['latent_stored_mix'] + label_stored = self.lerp_params['label_stored'] + label_stored_mix = self.lerp_params['label_stored_mix'] encoding_stored = self.lerp_params['encoding_stored'] - encoding_mix = self.lerp_params['encoding_mix'] - label = self.lerp_params['label'] + encoding_stored_mix = self.lerp_params['encoding_stored_mix'] + # if we're showing an encoding already, lerp to the next one if encoding_mix.n.value > 0: encoding_stored.switch(target_value=new_encoding) - label.switch(target_value=new_label) + label_stored.switch(target_value=new_label) + latent_stored.switch(target_value=new_latent) # otherwise (we're showing the latent)... else: # jump to the stored encoding, then switch if encoding_stored.n.value < 0.5: - encoding_stored.n.value = 0 + encoding_stored.n.assign(0) encoding_stored.a.assign(new_encoding) + latent_stored.a.assign(new_latent) + label_stored.a.assign(new_label) else: - encoding_stored.n.value = 1 + encoding_stored.n.assign(1) encoding_stored.b.assign(new_encoding) + latent_stored.b.assign(new_latent) + label_stored.b.assign(new_label) encoding_mix.switch() - label.switch(target_value=new_label) def on_step(self, i, dt, sess): for param in self.sin_params.values(): diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py index c7cf078..616ba6a 100644 --- a/cli/app/search/search_dense.py +++ b/cli/app/search/search_dense.py @@ -280,7 +280,7 @@ def find_dense_embedding_for_images(params): params.inv_it / params.decay_n, 0.1, staircase=True) else: lrate = tf.constant(params.lr) - trained_params = [encoding] if params.fixed_z else [latent, encoding] + trained_params = [latent, encoding] optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999) inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params, global_step=inv_step) @@ -406,12 +406,12 @@ def find_dense_embedding_for_images(params): it += 1 # Save images that are ready. - latent_batch, enc_batch, rec_err_batch = sess.run([latent, encoding, img_rec_err]) - out_lat[out_pos:out_pos+BATCH_SIZE] = latent_batch - out_enc[out_pos:out_pos+BATCH_SIZE] = enc_batch + label_trained, latent_trained, enc_trained, rec_err_trained = sess.run([label, latent, encoding, img_rec_err]) + out_lat[out_pos:out_pos+BATCH_SIZE] = latent_trained + out_enc[out_pos:out_pos+BATCH_SIZE] = enc_trained out_images[out_pos:out_pos+BATCH_SIZE] = image_batch - out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch - out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch + out_labels[out_pos:out_pos+BATCH_SIZE] = label_trained + out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_trained gen_images = sess.run(gen_img_orig) images = vs.data2img(gen_images) |
