diff options
| -rw-r--r-- | cli/app/search/search_dense.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py index 36fc230..823b91e 100644 --- a/cli/app/search/search_dense.py +++ b/cli/app/search/search_dense.py @@ -171,9 +171,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op if params.mse: pix_square_diff = tf.square((target_img - gen_img) / 2.0) mse_loss = tf.reduce_mean(pix_square_diff) + ssim_loss = 1 - tf.image.ssim(im1, im2, max_val=1.0) img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3]) else: mse_loss = tf.constant(0.0) + ssim_loss = tf.constant(0.0) img_mse_err = tf.constant(0.0) # Use custom features for image comparison. @@ -289,7 +291,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op + params.lambda_feat * img_feat_err # Batch reconstruction error. - rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss + rec_loss = params.lambda_mse * ssim_loss * params.lambda_mse * mse_loss + params.lambda_feat * feat_loss # Total inversion loss. inv_loss = rec_loss @@ -399,8 +401,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op for _ in range(params.inv_it): _inv_loss, _mse_loss, _feat_loss,\ - _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, - lrate, inv_train_op]) + _lrate, _ssim_loss, _ = sess.run([inv_loss, mse_loss, feat_loss, + lrate, ssim_loss, inv_train_op]) if params.clipping or params.stochastic_clipping: sess.run(clip_latent) @@ -410,9 +412,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op # Log losses. etime = time.time() - start_time print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] ' - 'feat [{:.4f}] ' + 'feat [{:.4f}] ssim [{:.4f}] ' 'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss, - _feat_loss, _lrate)) + _feat_loss, _ssim_loss, _lrate)) sys.stdout.flush() |
