diff options
| -rw-r--r-- | inversion/live.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/inversion/live.py b/inversion/live.py index 08cc789..b3e41b1 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -85,7 +85,7 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)): # -------------------------- class SinParam: - def __init__(self, name, shape, type): + def __init__(self, name, shape, type="noise"): noise = LerpParam(name + '_noise', shape, type=type) orbit_radius = InterpolatorParam(name=name + '_radius', value=0.1) orbit_speed = InterpolatorParam(name=name + '_speed', value=0.1) @@ -168,7 +168,7 @@ class Interpolator: self.opts['num_classes'] = InterpolatorParam(name='num_classes', value=1.0), lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM]) sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM]) - lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS]) + lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], type="label") # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], type='noise') # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], type='label') |
