summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-02-04 18:00:18 +0100
committerJules Laplace <julescarbon@gmail.com>2020-02-04 18:00:18 +0100
commit48f798a3a968bb7a73a98c77c309c9f4f8474ba8 (patch)
treed8d5f85792599e4cf452af77f22dbf7786c69d5c
parent5d0f39e487abb171dd95a50c7994e06bf67f344d (diff)
incorporating ssim loss
-rw-r--r--cli/app/search/search_dense.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index 36fc230..823b91e 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -171,9 +171,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
if params.mse:
pix_square_diff = tf.square((target_img - gen_img) / 2.0)
mse_loss = tf.reduce_mean(pix_square_diff)
+ ssim_loss = 1 - tf.image.ssim(im1, im2, max_val=1.0)
img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1,2,3])
else:
mse_loss = tf.constant(0.0)
+ ssim_loss = tf.constant(0.0)
img_mse_err = tf.constant(0.0)
# Use custom features for image comparison.
@@ -289,7 +291,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
+ params.lambda_feat * img_feat_err
# Batch reconstruction error.
- rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss
+ rec_loss = params.lambda_mse * ssim_loss * params.lambda_mse * mse_loss + params.lambda_feat * feat_loss
# Total inversion loss.
inv_loss = rec_loss
@@ -399,8 +401,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
for _ in range(params.inv_it):
_inv_loss, _mse_loss, _feat_loss,\
- _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss,
- lrate, inv_train_op])
+ _lrate, _ssim_loss, _ = sess.run([inv_loss, mse_loss, feat_loss,
+ lrate, ssim_loss, inv_train_op])
if params.clipping or params.stochastic_clipping:
sess.run(clip_latent)
@@ -410,9 +412,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# Log losses.
etime = time.time() - start_time
print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
- 'feat [{:.4f}] '
+ 'feat [{:.4f}] ssim [{:.4f}] '
'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
- _feat_loss, _lrate))
+ _feat_loss, _ssim_loss, _lrate))
sys.stdout.flush()