summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCameron <cysmith1010@gmail.com>2017-08-02 14:15:11 +0100
committerGitHub <noreply@github.com>2017-08-02 14:15:11 +0100
commit9e4edc2158ecf5191862a6f2c6b35bd47ac05c17 (patch)
tree7df6fe3a80b39b49f8770ad89ea11653bdd1e6d0
parent4c049592409a3e4cf8519f8ab2577ffc6adaae3b (diff)
Added total variation loss fn
-rw-r--r--neural_style.py20
1 files changed, 2 insertions, 18 deletions
diff --git a/neural_style.py b/neural_style.py
index cbf70a7..6dce589 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -468,22 +468,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img):
return loss
'''
- denoising loss function
-'''
-def sum_total_variation_losses(sess, net, input_img):
- b, h, w, d = input_img.shape
- x = net['input']
- tv_y_size = b * (h-1) * w * d
- tv_x_size = b * h * (w-1) * d
- loss_y = tf.nn.l2_loss(x[:,1:,:,:] - x[:,:-1,:,:])
- loss_y /= tv_y_size
- loss_x = tf.nn.l2_loss(x[:,:,1:,:] - x[:,:,:-1,:])
- loss_x /= tv_x_size
- loss = 2 * (loss_y + loss_x)
- loss = tf.cast(loss, tf.float32)
- return loss
-
-'''
utilities and i/o
'''
def read_image(path):
@@ -575,7 +559,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
L_content = sum_content_losses(sess, net, content_img)
# denoising loss
- L_tv = sum_total_variation_losses(sess, net, init_img)
+ L_tv = tf.image.total_variation(net['input'])
# loss weights
alpha = args.content_weight
@@ -869,4 +853,4 @@ def main():
else: render_single_image()
if __name__ == '__main__':
- main() \ No newline at end of file
+ main()