summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-19 18:00:40 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-19 18:00:40 +0100
commit61a7dcf6bd99e167809ee9eebefec89fd4f402df (patch)
treee0e82570b25d48e5074fd356dbcda6120aaa4491 /inversion
parent514c89e1d1dbcb1c35de5e7da73b5ac058cb46f4 (diff)
logging
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 8a84250..a97e17f 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -1,4 +1,5 @@
import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sys
import glob
import h5py
@@ -52,7 +53,7 @@ N_CLASS = input_info['y'].get_shape().as_list()[1]
# Initializers
# --------------------------
-def label_sampler(num_classes=2, shape=(BATCH_SIZE, N_CLASS,)):
+def label_sampler(num_classes=1, shape=(BATCH_SIZE, N_CLASS,)):
label = np.zeros(shape)
for i in range(shape[0]):
for _ in range(random.randint(1, shape[1])):
@@ -109,11 +110,11 @@ class InterpolatorParam:
def assign(self, value):
self.value = value
- def randomize(self, num_classes):
+ def randomize(self, num_classes=1, truncation=1.0):
if self.type == 'noise':
- val = truncated_z_sample(shape=self.shape)
+ val = truncated_z_sample(shape=self.shape, truncation=truncation)
elif self.type == 'label':
- val = label_sampler(shape=self.shape)
+ val = label_sampler(shape=self.shape, num_classes=num_classes)
self.assign(val)
# --------------------------
@@ -159,7 +160,10 @@ class Interpolator:
return opt
def set_value(self, key, value):
- self.opts[key].assign(value)
+ if key in opts:
+ self.opts[key].assign(value)
+ else:
+ sys.stderr.write('{} not a valid option\n'.format(key))
def on_step(self, i, sess):
gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())