diff options
Diffstat (limited to 'cli/app/commands/biggan/search_class.py')
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index 78b7b2d..4c797ae 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -41,7 +41,6 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): """ Search for an image (class vector) in BigGAN using gradient descent """ - vocab_size = input_y.shape.as_list()[1] sess = tf.compat.v1.Session() sess.run(tf.compat.v1.global_variables_initializer()) @@ -60,7 +59,7 @@ def cli(ctx, opt_fp_in, opt_dims, opt_steps, opt_limit, opt_video, opt_tag): save_params_dense(fp_inverses, opt_tag) out_file = h5py.File(join(fp_inverses, 'dataset.hdf5'), 'w') out_images = out_file.create_dataset('xtrain', (len(paths), 3, 512, 512,), dtype='float32') - out_labels = out_file.create_dataset('ytrain', (len(paths), vocab_size,), dtype='float32') + out_labels = out_file.create_dataset('ytrain', (len(paths), 1000,), dtype='float32') out_latent = out_file.create_dataset('ztrain', (len(paths), 128,), dtype='float32') out_fns = out_file.create_dataset('fn', (len(paths),), dtype=h5py.string_dtype()) for index, path in enumerate(paths): @@ -82,8 +81,9 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l batch_size = 1 truncation = 1.0 - z_dim = 512 + z_dim = 128 vocab_size = 1000 + img_size = 512 num_channels = 3 z_initial = truncated_z_sample(batch_size, z_dim, truncation/2) @@ -101,7 +101,7 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l 'truncation': input_trunc, }) - target = tf.placeholder(tf.float32, shape=(batch_size, z_dim, z_dim, num_channels)) + target = tf.placeholder(tf.float32, shape=(batch_size, img_size, img_size, num_channels)) # loss = tf.losses.compute_weighted_loss(tf.square(output - target), weights=mask) loss = tf.losses.mean_squared_error(target, output) |
