summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--neural_style.py19
1 files changed, 11 insertions, 8 deletions
diff --git a/neural_style.py b/neural_style.py
index 985ad85..756d19b 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -60,7 +60,7 @@ def parse_args():
help='Weight for the style loss function. (default: %(default)s)')
parser.add_argument('--tv_weight', type=float,
- default=0,
+ default=1e-3,
help='Weight for the transvariational loss function. Set small (e.g. 1e-3). (default: %(default)s)')
parser.add_argument('--temporal_weight', type=float,
@@ -440,7 +440,8 @@ def get_longterm_weights(i, j):
c_max = tf.maximum(c - c_sum, 0.)
return c_max
-def sum_longterm_temporal_losses(net, frame, x):
+def sum_longterm_temporal_losses(sess, net, frame, x):
+ x = sess.run(net['input'].assign(x))
loss = 0.
for j in range(args.prev_frame_indices):
prev_frame = frame - j
@@ -449,7 +450,8 @@ def sum_longterm_temporal_losses(net, frame, x):
loss += temporal_loss(x, w, c)
return loss
-def sum_shortterm_temporal_losses(net, frame, x):
+def sum_shortterm_temporal_losses(sess, net, frame, x):
+ x = sess.run(net['input'].assign(x))
prev_frame = frame - 1
w = get_prev_warped_frame(frame)
c = get_content_weights(frame, prev_frame)
@@ -461,8 +463,9 @@ def sum_shortterm_temporal_losses(net, frame, x):
remark: not sure this does anything significant.
'''
-def sum_total_variation_losses(x):
+def sum_total_variation_losses(sess, net, x):
b, h, w, d = x.shape
+ x = sess.run(net['input'].assign(x))
tv_y_size = b * (h-1) * w * d
tv_x_size = b * h * (w-1) * d
loss_y = tf.nn.l2_loss(x[:,1:,:,:] - x[:,:h-1,:,:])
@@ -556,7 +559,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
L_content = sum_content_losses(sess, net, content_img)
# denoising loss
- L_tv = sum_total_variation_losses(init_img)
+ L_tv = sum_total_variation_losses(sess, net, init_img)
# loss weights
alpha = args.content_weight
@@ -570,7 +573,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
if args.video and frame > 1:
gamma = args.temporal_weight
- L_temporal = sum_shortterm_temporal_losses(sess, frame, init_img)
+ L_temporal = sum_shortterm_temporal_losses(sess, net, frame, init_img)
L_total += gamma * L_temporal
# optimization algorithm
@@ -584,7 +587,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
output_img = sess.run(net['input'])
if args.original_colors:
- output_img = convert_to_original_colors(np.copy(content_img), np.copy(output_img))
+ output_img = convert_to_original_colors(np.copy(content_img), output_img)
if args.video:
write_video_output(frame, output_img)
@@ -609,7 +612,7 @@ def minimize_with_adam(sess, net, optimizer, init_img, loss):
sess.run(train_op)
if iterations % args.print_iterations == 0 and args.verbose:
curr_loss = loss.eval()
- print("At iterate {}\tf= {:.2E}".format(iterations, curr_loss))
+ print("At iterate {}\tf= {:.5E}".format(iterations, curr_loss))
iterations += 1
def get_optimizer(loss):