diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 20:54:02 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 20:54:02 +0100 |
| commit | d9af07ade49673185f5bcfb1a007e9c5a05c59af (patch) | |
| tree | 863957f5c7133471e95c783ff495debd67647510 /inversion/live.py | |
| parent | 2a2c7c5ff57c13afd2d2fd6af9525cf0cb5b151e (diff) | |
sending command
Diffstat (limited to 'inversion/live.py')
| -rw-r--r-- | inversion/live.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/inversion/live.py b/inversion/live.py index 711df22..84bfa25 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -139,7 +139,10 @@ class InterpolatorParam: self.scalar = shape == () self.shape = shape self.datatype = datatype - self.value = (value or 0.0) if datatype == "float" else np.zeros(shape) + if datatype == "float": + self.assign(value or 0.0) + else: + self.randomize() self.variable = tf.placeholder(dtype=dtype, shape=shape) interpolator.opts[name] = self @@ -151,9 +154,11 @@ class InterpolatorParam: def randomize(self): if self.datatype == 'noise': - val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value) + val = truncated_z_sample(shape=self.shape, truncation=interpolator.opts['truncation'].value) elif self.datatype == 'label': - val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value) + val = label_sampler(shape=self.shape, num_classes=interpolator.opts['num_classes'].value) + else: + val = 0.0 self.assign(val) # -------------------------- |
