summaryrefslogtreecommitdiff
path: root/Code/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'Code/utils.py')
-rw-r--r--Code/utils.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/Code/utils.py b/Code/utils.py
index 6c9f891..39a7e11 100644
--- a/Code/utils.py
+++ b/Code/utils.py
@@ -167,7 +167,7 @@ def psnr_error(gen_frames, gt_frames):
batch.
"""
shape = tf.shape(gen_frames)
- num_pixels = tf.to_float(shape[1] * shape[2])
+ num_pixels = tf.to_float(shape[1] * shape[2] * shape[3])
square_diff = tf.square(gt_frames - gen_frames)
batch_errors = 10 * log10(1 / ((1 / num_pixels) * tf.reduce_sum(square_diff, [1, 2, 3])))
@@ -186,7 +186,7 @@ def sharp_diff_error(gen_frames, gt_frames):
@return: A scalar tensor. The Sharpness Difference error over each frame in the batch.
"""
shape = tf.shape(gen_frames)
- num_pixels = tf.to_float(shape[1] * shape[2])
+ num_pixels = tf.to_float(shape[1] * shape[2] * shape[3])
# gradient difference
# create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively.