From ed24122b4b205455c5672b39c4180e2af5b31774 Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Mon, 6 Jan 2020 15:20:40 +0100 Subject: path --- cli/app/commands/biggan/search_class.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) (limited to 'cli/app/commands') diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index 9076e26..8895a67 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -43,8 +43,6 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ sess = tf.compat.v1.Session() - sess.run(tf.compat.v1.global_variables_initializer()) - sess.run(tf.compat.v1.tables_initializer()) if os.path.isdir(opt_fp_in): paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \ @@ -95,17 +93,17 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) # input_trunc = tf.compat.v1.constant(1.0) - input_trunc = tf.placeholder(tf.float32, shape=None) + input_trunc = tf.compat.v1.placeholder(tf.float32, shape=None) output = generator({ 'z': input_z, 'y': input_y, 'truncation': input_trunc, }) - target = tf.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) + target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask) - loss = tf.losses.mean_squared_error(target, output) + loss = tf.compat.v1.losses.mean_squared_error(target, output) train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ') train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY') @@ -123,6 +121,10 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) + # IMPORTANT: initialize variables before running the session + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.compat.v1.tables_initializer()) + # feed_dict = {input_z: z, input_y: y, input_trunc: truncation} phi_start = sess.run(output, { input_trunc: truncation, -- cgit v1.2.3-70-g09d2