diff options
| -rw-r--r-- | inversion/live.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/inversion/live.py b/inversion/live.py index 711df22..84bfa25 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -139,7 +139,10 @@ class InterpolatorParam: self.scalar = shape == () self.shape = shape self.datatype = datatype - self.value = (value or 0.0) if datatype == "float" else np.zeros(shape) + if datatype == "float": + self.assign(value or 0.0) + else: + self.randomize() self.variable = tf.placeholder(dtype=dtype, shape=shape) interpolator.opts[name] = self @@ -151,9 +154,11 @@ class InterpolatorParam: def randomize(self): if self.datatype == 'noise': - val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value) + val = truncated_z_sample(shape=self.shape, truncation=interpolator.opts['truncation'].value) elif self.datatype == 'label': - val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value) + val = label_sampler(shape=self.shape, num_classes=interpolator.opts['num_classes'].value) + else: + val = 0.0 self.assign(val) # -------------------------- |
