From 9e4edc2158ecf5191862a6f2c6b35bd47ac05c17 Mon Sep 17 00:00:00 2001 From: Cameron Date: Wed, 2 Aug 2017 14:15:11 +0100 Subject: Added total variation loss fn --- neural_style.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/neural_style.py b/neural_style.py index cbf70a7..6dce589 100644 --- a/neural_style.py +++ b/neural_style.py @@ -467,22 +467,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img): loss = temporal_loss(x, w, c) return loss -''' - denoising loss function -''' -def sum_total_variation_losses(sess, net, input_img): - b, h, w, d = input_img.shape - x = net['input'] - tv_y_size = b * (h-1) * w * d - tv_x_size = b * h * (w-1) * d - loss_y = tf.nn.l2_loss(x[:,1:,:,:] - x[:,:-1,:,:]) - loss_y /= tv_y_size - loss_x = tf.nn.l2_loss(x[:,:,1:,:] - x[:,:,:-1,:]) - loss_x /= tv_x_size - loss = 2 * (loss_y + loss_x) - loss = tf.cast(loss, tf.float32) - return loss - ''' utilities and i/o ''' @@ -575,7 +559,7 @@ def stylize(content_img, style_imgs, init_img, frame=None): L_content = sum_content_losses(sess, net, content_img) # denoising loss - L_tv = sum_total_variation_losses(sess, net, init_img) + L_tv = tf.image.total_variation(net['input']) # loss weights alpha = args.content_weight @@ -869,4 +853,4 @@ def main(): else: render_single_image() if __name__ == '__main__': - main() \ No newline at end of file + main() -- cgit v1.2.3-70-g09d2