summaryrefslogtreecommitdiff
path: root/cli
diff options
context:
space:
mode:
Diffstat (limited to 'cli')
-rw-r--r--cli/app/commands/biggan/search_class.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index 9076e26..8895a67 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -43,8 +43,6 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
"""
sess = tf.compat.v1.Session()
- sess.run(tf.compat.v1.global_variables_initializer())
- sess.run(tf.compat.v1.tables_initializer())
if os.path.isdir(opt_fp_in):
paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
@@ -95,17 +93,17 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
# input_trunc = tf.compat.v1.constant(1.0)
- input_trunc = tf.placeholder(tf.float32, shape=None)
+ input_trunc = tf.compat.v1.placeholder(tf.float32, shape=None)
output = generator({
'z': input_z,
'y': input_y,
'truncation': input_trunc,
})
- target = tf.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
+ target = tf.compat.v1.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
# loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask)
- loss = tf.losses.mean_squared_error(target, output)
+ loss = tf.compat.v1.losses.mean_squared_error(target, output)
train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ')
train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY')
@@ -123,6 +121,10 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
phi_target = np.expand_dims(phi_target, 0)
phi_target = np.repeat(phi_target, batch_size, axis=0)
+ # IMPORTANT: initialize variables before running the session
+ sess.run(tf.compat.v1.global_variables_initializer())
+ sess.run(tf.compat.v1.tables_initializer())
+
# feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
phi_start = sess.run(output, {
input_trunc: truncation,