summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-19 21:16:28 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-19 21:16:28 +0100
commitdce014e69e4fc2bd9a110d5a9e7334a9824be215 (patch)
tree0f234609702f9ea8ed1f807285a522e9a34b2a19
parentd9af07ade49673185f5bcfb1a007e9c5a05c59af (diff)
sending command
-rw-r--r--inversion/live.py9
1 files changed, 5 insertions, 4 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 84bfa25..be03422 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -70,6 +70,7 @@ def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)):
for _ in range(random.randint(1, shape[1])):
j = random.randint(0, shape[1]-1)
label[i, j] = random.random()
+ print("class: {} {}".format(j, label[i, j]))
label[i] /= label[i].sum()
return label
@@ -86,7 +87,7 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
class SinParam:
def __init__(self, name, shape, datatype="noise"):
- noise = LerpParam(name + '_noise', shape, datatype=datatype)
+ noise = LerpParam(name + '_noise', shape=shape, datatype=datatype)
orbit_radius = InterpolatorParam(name=name + '_radius', value=0.1)
orbit_speed = InterpolatorParam(name=name + '_speed', value=1.0)
orbit_time = InterpolatorParam(name=name + '_time', value=0.0)
@@ -174,9 +175,9 @@ class Interpolator:
def build(self):
InterpolatorParam(name='truncation', value=1.0),
InterpolatorParam(name='num_classes', value=1.0),
- lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM], datatype="noise")
- sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM], datatype="noise")
- lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], datatype="label")
+ lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+ sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
+ lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
# self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise')
# self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label')