diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 20:26:39 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 20:26:39 +0100 |
| commit | 3e229d4e422dc6abf41028a52f5bfa3e9f4f0c39 (patch) | |
| tree | a4c6d40c19cdcaa349e6ebfb524a2af09001706c /inversion/live.py | |
| parent | 948e20eba7a2e4f5f8a9a853cb2d32f16b2350f5 (diff) | |
type checking
Diffstat (limited to 'inversion/live.py')
| -rw-r--r-- | inversion/live.py | 30 |
1 files changed, 15 insertions, 15 deletions
diff --git a/inversion/live.py b/inversion/live.py index 2dc4970..979e137 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -85,8 +85,8 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)): # -------------------------- class SinParam: - def __init__(self, name, shape, type="noise"): - noise = LerpParam(name + '_noise', shape, type=type) + def __init__(self, name, shape, datatype="noise"): + noise = LerpParam(name + '_noise', shape, datatype=datatype) orbit_radius = InterpolatorParam(name=name + '_radius', value=0.1) orbit_speed = InterpolatorParam(name=name + '_speed', value=1.0) orbit_time = InterpolatorParam(name=name + '_time', value=0.0) @@ -101,9 +101,9 @@ class SinParam: self.orbit_time.value += self.orbit_speed.value * dt class LerpParam: - def __init__(self, name, shape, type="noise"): - a = InterpolatorParam(name=name + '_a', shape=shape, type=type) - b = InterpolatorParam(name=name + '_b', shape=shape, type=type) + def __init__(self, name, shape, datatype="noise"): + a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype) + b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype) n = InterpolatorParam(name=name + '_n', value=0.0) speed = InterpolatorParam(name=name + '_speed', value=0.1) output = a.variable * (1 - n.variable) + b.variable * n.variable @@ -135,11 +135,11 @@ class LerpParam: # -------------------------- class InterpolatorParam: - def __init__(self, name, dtype=tf.float32, shape=(), value=None, type=float): + def __init__(self, name, dtype=tf.float32, shape=(), value=None, datatype=float): self.scalar = shape == () self.shape = shape - self.type = type - self.value = (value or 0.0) if type == float else np.zeros(shape) + self.datatype = datatype + self.value = (value or 0.0) if datatype == float else np.zeros(shape) self.variable = tf.placeholder(dtype=dtype, shape=shape) interpolator.opts[name] = self @@ -147,9 +147,9 @@ class InterpolatorParam: self.value = value def randomize(self): - if self.type == 'noise': + if self.datatype == 'noise': val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value) - elif self.type == 'label': + elif self.datatype == 'label': val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value) self.assign(val) @@ -166,12 +166,12 @@ class Interpolator: def build(self): InterpolatorParam(name='truncation', value=1.0), InterpolatorParam(name='num_classes', value=1.0), - lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM]) - sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM]) - lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], type="label") + lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM], datatype="noise") + sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM], datatype="noise") + lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], datatype="label") - # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], type='noise') - # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], type='label') + # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise') + # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label') gen_in = {} gen_in['truncation'] = 1.0 # self.opts['truncation'].variable |
