diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-01-12 03:32:22 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-01-12 03:32:22 +0100 |
| commit | 745ef49ffb5bc2f2eefed349149d7effd1de5f2b (patch) | |
| tree | 683c1ac0a265e20ed493e9527760c667549a343d | |
| parent | 35a7c163bc6928eae9c54f2a3e686ec308dcf03a (diff) | |
disentangled vectors
| -rw-r--r-- | cli/app/search/live.py | 27 | ||||
| -rw-r--r-- | cli/app/settings/app_cfg.py | 1 |
2 files changed, 27 insertions, 1 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py index a1c1700..14f1ad3 100644 --- a/cli/app/search/live.py +++ b/cli/app/search/live.py @@ -21,6 +21,7 @@ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../../live-cortex/rpc/')) from rpc import CortexRPC from app.search.params import timestamp +from app.utils.tf_utils import read_checkpoint FPS = 25 @@ -83,6 +84,23 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)): return np.random.normal(size=shape) # -------------------------- +# Disentangled Latents +# -------------------------- + +disentangled = { + 'zoom': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'zoom/model.ckpt'), 'walk')[0], + 'shiftx': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'shiftx/model.ckpt'), 'walk')[0], + 'shifty': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'shifty/model.ckpt'), 'walk')[0], + 'rotate2d': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0], + 'rotate3d': read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'rotate3d/model.ckpt'), 'walk')[0], +} +disentangled_color = read_checkpoint(os.path.join(app_cfg.DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0] +disentangled['r'] = disentangled_color[:, 0] +disentangled['g'] = disentangled_color[:, 1] +disentangled['b'] = disentangled_color[:, 2] +disentangled['luminance'] = np.sum(disentangled_color, axis=1) + +# -------------------------- # More complex ops # -------------------------- @@ -209,9 +227,16 @@ class Interpolator: z_abs = z_sum / tf.abs(z_sum) * abs_zoom.variable z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input") + # Latent - disentangled vectors + zoom = InterpolatorParam(name='zoom', value=0.0).variable * disentangled['zoom'] + shiftx = InterpolatorParam(name='shiftx', value=0.0).variable * disentangled['shiftx'] + shifty = InterpolatorParam(name='shifty', value=0.0).variable * disentangled['shifty'] + luminance = InterpolatorParam(name='luminance', value=0.0).variable * disentangled['luminance'] + disentangled = z_mix.output + zoom + shiftx + shifty + luminance + # Latent - stored vector latent_stored = LerpParam(name='latent_stored', shape=[BATCH_SIZE, Z_DIM], datatype="noise") - latent_stored_mix = LerpParam('latent_stored_mix', a_in=z_mix.output, b_in=latent_stored.output, shape=[BATCH_SIZE, Z_DIM], datatype="input") + latent_stored_mix = LerpParam('latent_stored_mix', a_in=disentangled, b_in=latent_stored.output, shape=[BATCH_SIZE, Z_DIM], datatype="input") # Label lerp_label = LerpParam('label', shape=[BATCH_SIZE, N_CLASS], datatype="label") diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py index 149d1cf..b5c383b 100644 --- a/cli/app/settings/app_cfg.py +++ b/cli/app/settings/app_cfg.py @@ -36,6 +36,7 @@ DIR_INPUTS = join(DIR_APP, 'data_store/inputs') DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs') DIR_RESULTS = join(DIR_APP, 'data_store/results') DIR_RENDERS = join(DIR_APP, 'data_store/renders') +DIR_DISENTANGLED = join(DIR_APP, 'data_store/disentangled') FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml') os.makedirs(DIR_INVERSES, exist_ok=True) |
