summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--  cli/app/commands/biggan/search_class.py  7
-rw-r--r--  cli/app/search/vector.py  8
2 files changed, 12 insertions, 3 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index ab040c0..8ca31d6 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -24,14 +24,15 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from app.search.json import save_params_latent, save_params_dense
from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \
imread, imwrite, imgrid, resize_and_crop_image
-from app.search.vector import truncated_z_sample, truncated_z_single, create_labels
+from app.search.vector import truncated_z_sample, truncated_z_single, \
+ create_labels, create_labels_uniform
@click.command('')
@click.option('-i', '--input', 'opt_fp_in', required=True,
help='Path to input image')
@click.option('-d', '--dims', 'opt_dims', default=512, type=int,
help='Dimensions of BigGAN network (128, 256, 512)')
-@click.option('-s', '--steps', 'opt_steps', default=1000, type=int,
+@click.option('-s', '--steps', 'opt_steps', default=2000, type=int,
help='Number of optimization iterations')
@click.option('-l', '--limit', 'opt_limit', default=1000, type=int,
help='Limit the number of images to process')
@@ -93,7 +94,7 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
num_channels = 3
z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
- y_initial = create_labels(batch_size, vocab_size, 10)
+ y_initial = create_labels_uniform(batch_size, vocab_size)
z_lr = 0.001
y_lr = 0.001
diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py
index 89cd949..b118ef3 100644
--- a/cli/app/search/vector.py
+++ b/cli/app/search/vector.py
@@ -18,3 +18,11 @@ def create_labels(batch_size, vocab_size, num_classes):
label[i, j] = random.random()
label[i] /= label[i].sum()
return label
+
+def create_labels_uniform(batch_size, vocab_size):
+ label = np.zeros((batch_size, vocab_size))
+ for i in range(batch_size):
+ for j in range(vocab_size):
+ label[i, j] = random.uniform(1 / vocab_size, 2 / vocab_size)
+ label[i] /= label[i].sum()
+ return label