author     Jules Laplace <julescarbon@gmail.com>  2019-12-09 09:23:05 +0100
committer  Jules Laplace <julescarbon@gmail.com>  2019-12-09 09:23:05 +0100
commit     50711c691295304cc94dbaf7b1178e2057ed9b5e (patch)
tree       59f3d3caa10c7ab8f736c9bdb4da628a53dfa1ac /cli
parent     750f6e26d58a97a65d371839e15924ecaf5b844d (diff)
search
Diffstat (limited to 'cli')
-rw-r--r--  cli/app/commands/biggan/search.py  57
1 file changed, 13 insertions(+), 44 deletions(-)
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index ec4b0c1..d1e1a0a 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -104,9 +104,8 @@ def cli(ctx, opt_fp_in, opt_dims):
     # module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2')
     # module = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
-    inputs = {}
-    for k, v in module.get_input_info_dict().items():
-        inputs[k] = tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k, trainable=True)
+    inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
+              for k, v in module.get_input_info_dict().items()}
     input_z = inputs['z']
     input_y = inputs['y']
     input_trunc = inputs['truncation']
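
The dict comprehension above builds one feedable placeholder per input that the TF-Hub module declares; the dropped trainable=True kwarg was invalid, since tf.compat.v1.placeholder creates a graph input rather than a variable and accepts only (dtype, shape, name). A minimal standalone sketch of the same wiring, assuming TF 1.x graph mode and using one of the module handles commented out above:

    import tensorflow as tf
    import tensorflow_hub as hub

    # Load a BigGAN generator from TF-Hub (TF 1.x hub.Module API).
    module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2')

    # One placeholder per declared input ('z', 'y', 'truncation'), so the
    # search loop can feed raw NumPy arrays on every iteration.
    inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
              for k, v in module.get_input_info_dict().items()}
    output = module(inputs)
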
@@ -180,57 +179,27 @@ def cli(ctx, opt_fp_in, opt_dims):
     start_im = imgrid(imconvert_uint8(phi_start), cols=5)
     imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
-    cost_op = tf.losses.mean_squared_error(output, phi_target)
-    train_op = tf.train.AdamOptimizer(lr_z).minimize(cost_op)
-
     try:
         for i in range(1000):
             feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
-
             grad_z, grad_y = sess.run([dc_dz, dc_dy], feed_dict=feed_dict)
-
-            #with tf.GradientTape(watch_accessed_variables=False, persistent=True) as g:
-            #    g.watch(input_z)
-            #    g.watch(input_y)
-            #cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
-            #dc_dz = g.gradient(cost, input_z)
-            #dc_dy = g.gradient(cost, input_y)
-
-            #optimizer.apply_gradients([[dc_dz, input_z], [dc_dy, input_y]])
-            #optimizer.apply_gradients([[grad_z, input_z], [grad_y, input_y]])
-            print("________")
-            #print(z[0][0:10])
-            #print(grad_y[0])
             z -= grad_z * lr_z
             y -= grad_y * lr_y
-            # decay/attenuate learning rate to 0.05 of the original over 1000 frames
-            if i > 100:
-                lr_z *= 0.997
-            if i > 500:
-                lr_y *= 0.999
+            lr_z *= 0.997
+            lr_y *= 0.999
+
+            if i % 30 == 0:
+                lr_y *= 1.002
+                y = np.clip(y, 0, 1)
+                for j in range(batch_size):
+                    y[j] /= y[j].sum()
+            if i > 200 and i % 100 == 0:
+                mean = np.mean(y, axis=0)
+                y = y / 2 + mean / 2
             indices = np.logical_or(z <= -2*truncation, z >= +2*truncation)
             z[indices] = np.random.randn(np.count_nonzero(indices))
-            #print(z[0][0:10])
-            if i < 100:
-                if i % 30 == 0:
-                    lr_z *= 1.002
-                    y = np.clip(y, 0, 1)
-                    for j in range(batch_size):
-                        y[j] /= y[j].sum()
-            elif i < 300:
-                if i % 50 == 0:
-                    lr_z *= 1.001
-                    y = np.clip(y, 0, 1)
-                    for j in range(batch_size):
-                        y[j] /= y[j].sum()
-            elif i < 600:
-                if i % 60 == 0:
-                    y = np.clip(y, 0, 1)
-            else:
-                if i % 100 == 0:
-                    y = np.clip(y, 0, 1)
             feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
             phi_guess = sess.run(output, feed_dict=feed_dict)
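
Taken together, the rewritten loop body is plain gradient descent in NumPy on the latent vector z and the class vector y: decay both learning rates every step, periodically re-project y onto the probability simplex (clip to [0, 1], renormalize each row), later blend each y toward the batch mean, and resample any z entries that leave the truncated-normal range. A distilled sketch of that update with the per-row loop vectorized; the search_step helper and the array shapes are illustrative, not part of the file:

    import numpy as np

    def search_step(z, y, grad_z, grad_y, lr_z, lr_y, i, truncation):
        # grad_z, grad_y are dC/dz and dC/dy as evaluated by
        # sess.run([dc_dz, dc_dy], ...) in the loop above.
        z = z - grad_z * lr_z
        y = y - grad_y * lr_y

        # Exponential decay every iteration.
        lr_z *= 0.997
        lr_y *= 0.999

        # Every 30 steps: nudge lr_y back up, then clip y to [0, 1] and
        # renormalize each row so it stays a probability distribution.
        if i % 30 == 0:
            lr_y *= 1.002
            y = np.clip(y, 0, 1)
            y /= y.sum(axis=1, keepdims=True)

        # Later in the search, pull every class vector toward the batch mean.
        if i > 200 and i % 100 == 0:
            y = y / 2 + np.mean(y, axis=0) / 2

        # Resample z components that escaped [-2*truncation, +2*truncation].
        out = np.logical_or(z <= -2 * truncation, z >= 2 * truncation)
        z[out] = np.random.randn(np.count_nonzero(out))
        return z, y, lr_z, lr_y
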