summaryrefslogtreecommitdiff
path: root/cli/app/commands/biggan/search.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-11 00:15:20 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-11 00:15:20 +0100
commit12e015807c1db187895aeca495b77b367c543211 (patch)
tree3b62da451f7647cfe9e405ec03d368da66af1653 /cli/app/commands/biggan/search.py
parentcc4b6f510c984d70c4943e1f9a06082a0b1df381 (diff)
steps
Diffstat (limited to 'cli/app/commands/biggan/search.py')
-rw-r--r--cli/app/commands/biggan/search.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index a2e73eb..d2b5900 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -151,8 +151,9 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
# for j in range(batch_size):
# y[j] /= y[j].sum()
if i > 200 and i % 100 == 0:
+ n = i / opt_steps
mean = np.mean(y, axis=0)
- y = y * 3 / 4 + mean / 4
+ y = y * (1 - n) + mean * n
indices = np.logical_or(z <= -2*truncation, z >= +2*truncation)
z[indices] = np.random.randn(np.count_nonzero(indices))
@@ -166,7 +167,7 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
except KeyboardInterrupt:
pass
print(y.shape)
- out_labels[index] = y
+ out_labels[index] = np.mean(y, axis=0)
return fp_frames
def export_video(fp_frames):