summaryrefslogtreecommitdiff
path: root/cli
diff options
context:
space:
mode:
Diffstat (limited to 'cli')
-rw-r--r--cli/app/commands/biggan/search_class.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index 78b7b2d..4c797ae 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -41,7 +41,6 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
"""
Search for an image (class vector) in BigGAN using gradient descent
"""
- vocab_size = input_y.shape.as_list()[1]
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
@@ -60,7 +59,7 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag):
save_params_dense(fp_inverses, opt_tag)
out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w')
out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32')
- out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32')
+ out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32')
out_latent = out_file.create_dataset('ztrain', (len(paths), 128,), dtype='float32')
out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype())
for index, path in enumerate(paths):
@@ -82,8 +81,9 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
batch_size = 1
truncation = 1.0
- z_dim = 512
+ z_dim = 128
vocab_size = 1000
+ img_size = 512
num_channels = 3
z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
@@ -101,7 +101,7 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
'truncation': input_trunc,
})
- target = tf.placeholder(tf.float32, shape=(batch_size, z_dim, z_dim, num_channels))
+ target = tf.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels))
# loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask)
loss = tf.losses.mean_squared_error(target, output)