diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 18:33:37 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-06 18:33:37 +0100 |
| commit | 7ea77a044ca9de9d8089bf382640fb4a7bfabc0f (patch) | |
| tree | 7f344845e3346e5f27f7f533b487c74408bab27f /cli/app | |
| parent | 4d0db1188bc94970550c02ff55f16110c5e86700 (diff) | |
create_labels_uniform
Diffstat (limited to 'cli/app')
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 7 | ||||
| -rw-r--r-- | cli/app/search/vector.py | 8 |
2 files changed, 12 insertions, 3 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index ab040c0..8ca31d6 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -24,14 +24,15 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) from app.search.json import save_params_latent, save_params_dense from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \ imread, imwrite, imgrid, resize_and_crop_image -from app.search.vector import truncated_z_sample, truncated_z_single, create_labels +from app.search.vector import truncated_z_sample, truncated_z_single, \ + create_labels, create_labels_uniform @click.command('') @click.option('-i', '--input', 'opt_fp_in', required=True, help='Path to input image') @click.option('-d', '--dims', 'opt_dims', default=512, type=int, help='Dimensions of BigGAN network (128, 256, 512)') -@click.option('-s', '--steps', 'opt_steps', default=1000, type=int, +@click.option('-s', '--steps', 'opt_steps', default=2000, type=int, help='Number of optimization iterations') @click.option('-l', '--limit', 'opt_limit', default=1000, type=int, help='Limit the number of images to process') @@ -93,7 +94,7 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la num_channels = 3 z_initial = truncated_z_sample(batch_size, z_dim, truncation/2) - y_initial = create_labels(batch_size, vocab_size, 10) + y_initial = create_labels_uniform(batch_size, vocab_size) z_lr = 0.001 y_lr = 0.001 diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py index 89cd949..b118ef3 100644 --- a/cli/app/search/vector.py +++ b/cli/app/search/vector.py @@ -18,3 +18,11 @@ def create_labels(batch_size, vocab_size, num_classes): label[i, j] = random.random() label[i] /= label[i].sum() return label + +def create_labels_uniform(batch_size, vocab_size): + label = np.zeros((batch_size, vocab_size)) + for i in range(batch_size): + for j in range(vocab_size): + label[i, j] = random.uniform(1 / vocab_size, 2 / vocab_size) + label[i] /= label[i].sum() + return label |
