summaryrefslogtreecommitdiff
path: root/neural_style.py
diff options
context:
space:
mode:
Diffstat (limited to 'neural_style.py')
-rw-r--r--neural_style.py20
1 files changed, 2 insertions, 18 deletions
diff --git a/neural_style.py b/neural_style.py
index cbf70a7..6dce589 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -468,22 +468,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img):
return loss
'''
- denoising loss function
-'''
-def sum_total_variation_losses(sess, net, input_img):
- b, h, w, d = input_img.shape
- x = net['input']
- tv_y_size = b * (h-1) * w * d
- tv_x_size = b * h * (w-1) * d
- loss_y = tf.nn.l2_loss(x[:,1:,:,:] - x[:,:-1,:,:])
- loss_y /= tv_y_size
- loss_x = tf.nn.l2_loss(x[:,:,1:,:] - x[:,:,:-1,:])
- loss_x /= tv_x_size
- loss = 2 * (loss_y + loss_x)
- loss = tf.cast(loss, tf.float32)
- return loss
-
-'''
utilities and i/o
'''
def read_image(path):
@@ -575,7 +559,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
L_content = sum_content_losses(sess, net, content_img)
# denoising loss
- L_tv = sum_total_variation_losses(sess, net, init_img)
+ L_tv = tf.image.total_variation(net['input'])
# loss weights
alpha = args.content_weight
@@ -869,4 +853,4 @@ def main():
else: render_single_image()
if __name__ == '__main__':
- main() \ No newline at end of file
+ main()