diff options
| -rw-r--r-- | cli/app/commands/biggan/search.py | 57 |
1 files changed, 13 insertions, 44 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py index ec4b0c1..d1e1a0a 100644 --- a/cli/app/commands/biggan/search.py +++ b/cli/app/commands/biggan/search.py @@ -104,9 +104,8 @@ def cli(ctx, opt_fp_in, opt_dims): # module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2') # module = hub.Module('https://tfhub.dev/deepmind/biggan-512/2') - inputs = {} - for k, v in module.get_input_info_dict().items(): - inputs[k] = tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k, trainable=True) + inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k) + for k, v in module.get_input_info_dict().items()} input_z = inputs['z'] input_y = inputs['y'] input_trunc = inputs['truncation'] @@ -180,57 +179,27 @@ def cli(ctx, opt_fp_in, opt_dims): start_im = imgrid(imconvert_uint8(phi_start), cols=5) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) - cost_op = tf.losses.mean_squared_error(output, phi_target) - train_op = tf.train.AdamOptimizer(lr_z).minimize(cost_op) - try: for i in range(1000): feed_dict = {input_z: z, input_y: y, input_trunc: truncation} - grad_z, grad_y = sess.run([dc_dz, dc_dy], feed_dict=feed_dict) - - #with tf.GradientTape(watch_accessed_variables=False, persistent=True) as g: - # g.watch(input_z) - # g.watch(input_y) - #cost = tf.reduce_sum(tf.pow(output - phi_target, 2)) - #dc_dz = g.gradient(cost, input_z) - #dc_dy = g.gradient(cost, input_y) - - #optimizer.apply_gradients([[dc_dz, input_z], [dc_dy, input_y]]) - #optimizer.apply_gradients([[grad_z, input_z], [grad_y, input_y]]) - print("________") - #print(z[0][0:10]) - #print(grad_y[0]) z -= grad_z * lr_z y -= grad_y * lr_y - # decay/attenuate learning rate to 0.05 of the original over 1000 frames - if i > 100: - lr_z *= 0.997 - if i > 500: - lr_y *= 0.999 + lr_z *= 0.997 + lr_y *= 0.999 + + if i % 30 == 0: + lr_y *= 1.002 + y = np.clip(y, 0, 1) + for j in range(batch_size): + y[j] /= y[j].sum() + if i > 200 and i % 100 == 0: + mean = np.mean(y, axis=0) + y = y / 2 + mean / 2 indices = np.logical_or(z <= -2*truncation, z >= +2*truncation) z[indices] = np.random.randn(np.count_nonzero(indices)) - #print(z[0][0:10]) - if i < 100: - if i % 30 == 0: - lr_z *= 1.002 - y = np.clip(y, 0, 1) - for j in range(batch_size): - y[j] /= y[j].sum() - elif i < 300: - if i % 50 == 0: - lr_z *= 1.001 - y = np.clip(y, 0, 1) - for j in range(batch_size): - y[j] /= y[j].sum() - elif i < 600: - if i % 60 == 0: - y = np.clip(y, 0, 1) - else: - if i % 100 == 0: - y = np.clip(y, 0, 1) feed_dict = {input_z: z, input_y: y, input_trunc: truncation} phi_guess = sess.run(output, feed_dict=feed_dict) |
