diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 15:20:40 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 15:20:40 +0100 |
| commit | ed24122b4b205455c5672b39c4180e2af5b31774 (patch) | |
| tree | c0b87c5d4333927b004901d9ae8c3988971d5fa1 /cli/app/commands/biggan/search_class.py | |
| parent | 4f625c70a4975654868c4c536def6ede472f4a93 (diff) | |
path
Diffstat (limited to 'cli/app/commands/biggan/search_class.py')
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index 9076e26..8895a67 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -43,8 +43,6 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ sess = tf.compat.v1.Session() - sess.run(tf.compat.v1.global_variables_initializer()) - sess.run(tf.compat.v1.tables_initializer()) if os.path.isdir(opt_fp_in): paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \ @@ -95,17 +93,17 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) # input_trunc = tf.compat.v1.constant(1.0) - input_trunc = tf.placeholder(tf.float32, shape=None) + input_trunc = tf.compat.v1.placeholder(tf.float32, shape=None) output = generator({ 'z': input_z, 'y': input_y, 'truncation': input_trunc, }) - target = tf.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) + target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask) - loss = tf.losses.mean_squared_error(target, output) + loss = tf.compat.v1.losses.mean_squared_error(target, output) train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ') train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY') @@ -123,6 +121,10 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) + # IMPORTANT: initialize variables before running the session + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.compat.v1.tables_initializer()) + # feed_dict = {input_z: z, input_y: y, input_trunc: truncation} phi_start = sess.run(output, { input_trunc: truncation, |
