diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 18:00:40 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-19 18:00:40 +0100 |
| commit | 61a7dcf6bd99e167809ee9eebefec89fd4f402df (patch) | |
| tree | e0e82570b25d48e5074fd356dbcda6120aaa4491 /inversion/live.py | |
| parent | 514c89e1d1dbcb1c35de5e7da73b5ac058cb46f4 (diff) | |
logging
Diffstat (limited to 'inversion/live.py')
| -rw-r--r-- | inversion/live.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/inversion/live.py b/inversion/live.py index 8a84250..a97e17f 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -1,4 +1,5 @@ import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import sys import glob import h5py @@ -52,7 +53,7 @@ N_CLASS = input_info['y'].get_shape().as_list()[1] # Initializers # -------------------------- -def label_sampler(num_classes=2, shape=(BATCH_SIZE, N_CLASS,)): +def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)): label = np.zeros(shape) for i in range(shape[0]): for _ in range(random.randint(1, shape[1])): @@ -109,11 +110,11 @@ class InterpolatorParam: def assign(self, value): self.value = value - def randomize(self, num_classes): + def randomize(self, num_classes=1, truncation=1.0): if self.type == 'noise': - val = truncated_z_sample(shape=self.shape) + val = truncated_z_sample(shape=self.shape, truncation=truncation) elif self.type == 'label': - val = label_sampler(shape=self.shape) + val = label_sampler(shape=self.shape, num_classes=num_classes) self.assign(val) # -------------------------- @@ -159,7 +160,10 @@ class Interpolator: return opt def set_value(self, key, value): - self.opts[key].assign(value) + if key in opts: + self.opts[key].assign(value) + else: + sys.stderr.write('{} not a valid option\n'.format(key)) def on_step(self, i, sess): gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict()) |
