diff options
| -rw-r--r-- | neural_style.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/neural_style.py b/neural_style.py index 985ad85..756d19b 100644 --- a/neural_style.py +++ b/neural_style.py @@ -60,7 +60,7 @@ def parse_args(): help='Weight for the style loss function. (default: %(default)s)') parser.add_argument('--tv_weight', type=float, - default=0, + default=1e-3, help='Weight for the transvariational loss function. Set small (e.g. 1e-3). (default: %(default)s)') parser.add_argument('--temporal_weight', type=float, @@ -440,7 +440,8 @@ def get_longterm_weights(i, j): c_max = tf.maximum(c - c_sum, 0.) return c_max -def sum_longterm_temporal_losses(net, frame, x): +def sum_longterm_temporal_losses(sess, net, frame, x): + x = sess.run(net['input'].assign(x)) loss = 0. for j in range(args.prev_frame_indices): prev_frame = frame - j @@ -449,7 +450,8 @@ def sum_longterm_temporal_losses(net, frame, x): loss += temporal_loss(x, w, c) return loss -def sum_shortterm_temporal_losses(net, frame, x): +def sum_shortterm_temporal_losses(sess, net, frame, x): + x = sess.run(net['input'].assign(x)) prev_frame = frame - 1 w = get_prev_warped_frame(frame) c = get_content_weights(frame, prev_frame) @@ -461,8 +463,9 @@ def sum_shortterm_temporal_losses(net, frame, x): remark: not sure this does anything significant. ''' -def sum_total_variation_losses(x): +def sum_total_variation_losses(sess, net, x): b, h, w, d = x.shape + x = sess.run(net['input'].assign(x)) tv_y_size = b * (h-1) * w * d tv_x_size = b * h * (w-1) * d loss_y = tf.nn.l2_loss(x[:,1:,:,:] - x[:,:h-1,:,:]) @@ -556,7 +559,7 @@ def stylize(content_img, style_imgs, init_img, frame=None): L_content = sum_content_losses(sess, net, content_img) # denoising loss - L_tv = sum_total_variation_losses(init_img) + L_tv = sum_total_variation_losses(sess, net, init_img) # loss weights alpha = args.content_weight @@ -570,7 +573,7 @@ def stylize(content_img, style_imgs, init_img, frame=None): if args.video and frame > 1: gamma = args.temporal_weight - L_temporal = sum_shortterm_temporal_losses(sess, frame, init_img) + L_temporal = sum_shortterm_temporal_losses(sess, net, frame, init_img) L_total += gamma * L_temporal # optimization algorithm @@ -584,7 +587,7 @@ def stylize(content_img, style_imgs, init_img, frame=None): output_img = sess.run(net['input']) if args.original_colors: - output_img = convert_to_original_colors(np.copy(content_img), np.copy(output_img)) + output_img = convert_to_original_colors(np.copy(content_img), output_img) if args.video: write_video_output(frame, output_img) @@ -609,7 +612,7 @@ def minimize_with_adam(sess, net, optimizer, init_img, loss): sess.run(train_op) if iterations % args.print_iterations == 0 and args.verbose: curr_loss = loss.eval() - print("At iterate {}\tf= {:.2E}".format(iterations, curr_loss)) + print("At iterate {}\tf= {:.5E}".format(iterations, curr_loss)) iterations += 1 def get_optimizer(loss): |
