diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2020-02-04 18:00:18 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2020-02-04 18:00:18 +0100 |
| commit | 48f798a3a968bb7a73a98c77c309c9f4f8474ba8 (patch) | |
| tree | d8d5f85792599e4cf452af77f22dbf7786c69d5c | |
| parent | 5d0f39e487abb171dd95a50c7994e06bf67f344d (diff) | |
incorporating ssim loss
| -rw-r--r-- | cli/app/search/search_dense.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py index 36fc230..823b91e 100644 --- a/cli/app/search/search_dense.py +++ b/cli/app/search/search_dense.py @@ -171,9 +171,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op if params.mse: pix_square_diff = tf.square((target_img - gen_img) / 2.0) mse_loss = tf.reduce_mean(pix_square_diff) + ssim_loss = 1 - tf.image.ssim(im1, im2, max_val=1.0) img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3]) else: mse_loss = tf.constant(0.0) + ssim_loss = tf.constant(0.0) img_mse_err = tf.constant(0.0) # Use custom features for image comparison. @@ -289,7 +291,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op + params.lambda_feat * img_feat_err # Batch reconstruction error. - rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss + rec_loss = params.lambda_mse * ssim_loss * params.lambda_mse * mse_loss + params.lambda_feat * feat_loss # Total inversion loss. inv_loss = rec_loss @@ -399,8 +401,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op for _ in range(params.inv_it): _inv_loss, _mse_loss, _feat_loss,\ - _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, - lrate, inv_train_op]) + _lrate, _ssim_loss, _ = sess.run([inv_loss, mse_loss, feat_loss, + lrate, ssim_loss, inv_train_op]) if params.clipping or params.stochastic_clipping: sess.run(clip_latent) @@ -410,9 +412,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op # Log losses. etime = time.time() - start_time print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] ' - 'feat [{:.4f}] ' + 'feat [{:.4f}] ssim [{:.4f}] ' 'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss, - _feat_loss, _lrate)) + _feat_loss, _ssim_loss, _lrate)) sys.stdout.flush() |
