diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 21:16:28 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 21:16:28 +0100 |
| commit | dce014e69e4fc2bd9a110d5a9e7334a9824be215 (patch) | |
| tree | 0f234609702f9ea8ed1f807285a522e9a34b2a19 /inversion/live.py | |
| parent | d9af07ade49673185f5bcfb1a007e9c5a05c59af (diff) | |
sending command
Diffstat (limited to 'inversion/live.py')
| -rw-r--r-- | inversion/live.py | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/inversion/live.py b/inversion/live.py index 84bfa25..be03422 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -70,6 +70,7 @@ def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)): for _ in range(random.randint(1, shape[1])): j = random.randint(0, shape[1]-1) label[i, j] = random.random() + print("class: {} {}".format(j, label[i, j])) label[i] /= label[i].sum() return label @@ -86,7 +87,7 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)): class SinParam: def __init__(self, name, shape, datatype="noise"): - noise = LerpParam(name + '_noise', shape, datatype=datatype) + noise = LerpParam(name + '_noise', shape=shape, datatype=datatype) orbit_radius = InterpolatorParam(name=name + '_radius', value=0.1) orbit_speed = InterpolatorParam(name=name + '_speed', value=1.0) orbit_time = InterpolatorParam(name=name + '_time', value=0.0) @@ -174,9 +175,9 @@ class Interpolator: def build(self): InterpolatorParam(name='truncation', value=1.0), InterpolatorParam(name='num_classes', value=1.0), - lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM], datatype="noise") - sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM], datatype="noise") - lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], datatype="label") + lerp_z = LerpParam('latent', shape=[BATCH_SIZE, Z_DIM], datatype="noise") + sin_z = SinParam('orbit', shape=[BATCH_SIZE, Z_DIM], datatype="noise") + lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label") # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise') # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label') |
