summaryrefslogtreecommitdiff
path: root/cli
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-12 03:32:22 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-12 03:32:22 +0100
commit745ef49ffb5bc2f2eefed349149d7effd1de5f2b (patch)
tree683c1ac0a265e20ed493e9527760c667549a343d /cli
parent35a7c163bc6928eae9c54f2a3e686ec308dcf03a (diff)
disentangled vectors
Diffstat (limited to 'cli')
-rw-r--r--cli/app/search/live.py27
-rw-r--r--cli/app/settings/app_cfg.py1
2 files changed, 27 insertions, 1 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
index a1c1700..14f1ad3 100644
--- a/cli/app/search/live.py
+++ b/cli/app/search/live.py
@@ -21,6 +21,7 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../../live-cortex/rpc/'))
from rpc import CortexRPC
from app.search.params import timestamp
+from app.utils.tf_utils import read_checkpoint
FPS = 25
@@ -83,6 +84,23 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
return np.random.normal(size=shape)
# --------------------------
+# Disentangled Latents
+# --------------------------
+
+disentangled = {
+ 'zoom': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'zoom/model.ckpt'), 'walk')[0],
+ 'shiftx': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'shiftx/model.ckpt'), 'walk')[0],
+ 'shifty': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'shifty/model.ckpt'), 'walk')[0],
+ 'rotate2d': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0],
+ 'rotate3d': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'rotate3d/model.ckpt'), 'walk')[0],
+}
+disentangled_color = read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0]
+disentangled['r'] = disentangled_color[:, 0]
+disentangled['g'] = disentangled_color[:, 1]
+disentangled['b'] = disentangled_color[:, 2]
+disentangled['luminance'] = np.sum(disentangled_color, axis=1)
+
+# --------------------------
# More complex ops
# --------------------------
@@ -209,9 +227,16 @@ class Interpolator:
z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable
z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input")
+ # Latent - disentangled vectors
+ zoom = InterpolatorParam(name='zoom', value=0.0).variable * disentangled['zoom']
+ shiftx = InterpolatorParam(name='shiftx', value=0.0).variable * disentangled['shiftx']
+ shifty = InterpolatorParam(name='shifty', value=0.0).variable * disentangled['shifty']
+ luminance = InterpolatorParam(name='luminance', value=0.0).variable * disentangled['luminance']
+ disentangled = z_mix.output + zoom + shiftx + shifty + luminance
+
# Latent - stored vector
latent_stored = LerpParam(name='latent_stored', shape=[BATCH_SIZE, Z_DIM], datatype="noise")
- latent_stored_mix = LerpParam('latent_stored_mix', a_in=z_mix.output, b_in=latent_stored.output, shape=[BATCH_SIZE, Z_DIM], datatype="input")
+ latent_stored_mix = LerpParam('latent_stored_mix', a_in=disentangled, b_in=latent_stored.output, shape=[BATCH_SIZE, Z_DIM], datatype="input")
# Label
lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label")
diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py
index 149d1cf..b5c383b 100644
--- a/cli/app/settings/app_cfg.py
+++ b/cli/app/settings/app_cfg.py
@@ -36,6 +36,7 @@ DIR_INPUTS = join(DIR_APP, 'data_store/inputs')
DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs')
DIR_RESULTS = join(DIR_APP, 'data_store/results')
DIR_RENDERS = join(DIR_APP, 'data_store/renders')
+DIR_DISENTANGLED = join(DIR_APP, 'data_store/disentangled')
FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml')
os.makedirs(DIR_INVERSES, exist_ok=True)