summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-06 15:20:40 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-06 15:20:40 +0100
commited24122b4b205455c5672b39c4180e2af5b31774 (patch)
treec0b87c5d4333927b004901d9ae8c3988971d5fa1
parent4f625c70a4975654868c4c536def6ede472f4a93 (diff)
path
-rw-r--r--cli/app/commands/biggan/search_class.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index 9076e26..8895a67 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -43,8 +43,6 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
"""
sess = tf.compat.v1.Session()
- sess.run(tf.compat.v1.global_variables_initializer())
- sess.run(tf.compat.v1.tables_initializer())
if os.path.isdir(opt_fp_in):
paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
@@ -95,17 +93,17 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
# input_trunc = tf.compat.v1.constant(1.0)
- input_trunc = tf.placeholder(tf.float32, shape=None)
+ input_trunc = tf.compat.v1.placeholder(tf.float32, shape=None)
output = generator({
'z': input_z,
'y': input_y,
'truncation': input_trunc,
})
- target = tf.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
+ target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
# loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask)
- loss = tf.losses.mean_squared_error(target, output)
+ loss = tf.compat.v1.losses.mean_squared_error(target, output)
train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ')
train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY')
@@ -123,6 +121,10 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
phi_target = np.expand_dims(phi_target, 0)
phi_target = np.repeat(phi_target, batch_size, axis=0)
+ # IMPORTANT: initialize variables before running the session
+ sess.run(tf.compat.v1.global_variables_initializer())
+ sess.run(tf.compat.v1.tables_initializer())
+
# feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
phi_start = sess.run(output, {
input_trunc: truncation,