Diffstat (limited to 'neural_style.py')
-rw-r--r--  neural_style.py  |  79
1 file changed, 40 insertions(+), 39 deletions(-)
diff --git a/neural_style.py b/neural_style.py
index 565dbc0..374b4c1 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -4,8 +4,7 @@ import scipy.io
import argparse
import struct
import time
-import cv2
-import csv
+import cv2
import os
'''
@@ -61,7 +60,7 @@ def parse_args():
help='Weight for the style loss function. (default: %(default)s)')
parser.add_argument('--tv_weight', type=float,
- default=0,
+ default=1e-3,
help='Weight for the total variation loss function. Set small (e.g. 1e-3). (default: %(default)s)')
parser.add_argument('--temporal_weight', type=float,
@@ -88,9 +87,7 @@ def parse_args():
parser.add_argument('--style_layer_weights', nargs='+', type=float,
default=[0.2, 0.2, 0.2, 0.2, 0.2],
help='Contributions (weights) of each style layer to loss. (default: %(default)s)')
-
- parser.add_argument('--style_scale', type=float, default=1.0)
-
+
parser.add_argument('--original_colors', action='store_true',
help='Transfer the style but not the colors.')
@@ -345,17 +342,12 @@ def content_layer_loss(p, x):
if args.content_loss_function == 1:
K = 1. / (2 * N**0.5 * M**0.5)
elif args.content_loss_function == 2:
- K = 1. / 2.
+ K = 1. / (N * M)
elif args.content_loss_function == 3:
- K = 1. / (N * M)
+ K = 1. / 2.
loss = K * tf.reduce_sum(tf.pow((x - p), 2))
return loss
-def gram_matrix(x, area, depth):
- F = tf.reshape(x[0], (area, depth))
- G = tf.matmul(tf.transpose(F), F)
- return G
-
def style_layer_loss(a, x):
_, h, w, d = a.get_shape()
M = h.value * w.value
@@ -365,6 +357,11 @@ def style_layer_loss(a, x):
loss = (1./(4 * N**2 * M**2)) * tf.reduce_sum(tf.pow((G - A), 2))
return loss
+def gram_matrix(x, area, depth):
+ F = tf.reshape(x[0], (area, depth))
+ G = tf.matmul(tf.transpose(F), F)
+ return G
+
def mask_style_layer(a, x, mask_img):
_, h, w, d = a.get_shape()
mask = get_mask_image(mask_img, w.value, h.value)
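For reference, the relocated gram_matrix helper reshapes an activation map to (area, depth) and forms G = F^T F; style_layer_loss then compares the Gram matrices of the style and generated activations with the 1/(4*N^2*M^2) normalization above. A minimal NumPy sketch of the same computation, with illustrative names that are not part of the repo:

import numpy as np

def gram_matrix_np(feats):
    # feats: (h, w, d) activation map for one image
    h, w, d = feats.shape
    F = feats.reshape(h * w, d)        # (area, depth)
    return F.T @ F                     # (depth, depth) Gram matrix

def style_layer_loss_np(a_feats, x_feats):
    # a_feats: style activations, x_feats: generated-image activations
    h, w, d = a_feats.shape
    M, N = h * w, d
    A = gram_matrix_np(a_feats)
    G = gram_matrix_np(x_feats)
    return np.sum((G - A) ** 2) / (4. * N**2 * M**2)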
@@ -443,7 +440,8 @@ def get_longterm_weights(i, j):
c_max = tf.maximum(c - c_sum, 0.)
return c_max
-def sum_longterm_temporal_losses(net, frame, x):
+def sum_longterm_temporal_losses(sess, net, frame, input_img):
+ x = sess.run(net['input'].assign(input_img))
loss = 0.
for j in range(args.prev_frame_indices):
prev_frame = frame - j
@@ -452,7 +450,8 @@ def sum_longterm_temporal_losses(net, frame, x):
loss += temporal_loss(x, w, c)
return loss
-def sum_shortterm_temporal_losses(net, frame, x):
+def sum_shortterm_temporal_losses(sess, net, frame, input_img):
+ x = sess.run(net['input'].assign(input_img))
prev_frame = frame - 1
w = get_prev_warped_frame(frame)
c = get_content_weights(frame, prev_frame)
@@ -462,10 +461,11 @@ def sum_shortterm_temporal_losses(net, frame, x):
'''
denoising loss function
- remark: not convinced this does anything significant.
+ remark: not sure this does anything significant.
'''
-def sum_total_variation_losses(x):
- b, h, w, d = x.shape
+def sum_total_variation_losses(sess, net, input_img):
+ b, h, w, d = input_img.shape
+ x = sess.run(net['input'].assign(input_img))
tv_y_size = b * (h-1) * w * d
tv_x_size = b * h * (w-1) * d
loss_y = tf.nn.l2_loss(x[:,1:,:,:] - x[:,:h-1,:,:])
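For reference, the total variation term penalizes differences between neighbouring pixels, which damps high-frequency noise in the stylized image. A rough NumPy equivalent of the term being built here, with illustrative names (tf.nn.l2_loss(t) is sum(t**2)/2; the final combination of the two directions is in lines not shown in this hunk):

import numpy as np

def total_variation_np(x):
    # x: (1, h, w, d) image batch, matching the shape unpacked above
    b, h, w, d = x.shape
    dy = x[:, 1:, :, :] - x[:, :h-1, :, :]   # vertical neighbour differences
    dx = x[:, :, 1:, :] - x[:, :, :w-1, :]   # horizontal neighbour differences
    loss_y = np.sum(dy ** 2) / 2.0           # analogue of tf.nn.l2_loss
    loss_x = np.sum(dx ** 2) / 2.0
    tv_y_size = b * (h - 1) * w * d
    tv_x_size = b * h * (w - 1) * d
    return loss_y / tv_y_size + loss_x / tv_x_size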
@@ -480,7 +480,7 @@ def sum_total_variation_losses(x):
utilities and i/o
'''
def read_image(path):
- # BGR image
+ # bgr image
img = cv2.imread(path, cv2.IMREAD_COLOR).astype('float')
img = preprocess(img, vgg19_mean)
return img
@@ -490,21 +490,19 @@ def write_image(path, img):
cv2.imwrite(path, img)
def preprocess(img, mean):
- # BGR to RGB
+ # bgr to rgb
img = img[...,::-1]
# shape (h, w, d) to (1, h, w, d)
img = img[np.newaxis,:,:,:]
- # subtract mean
img -= mean
return img
def postprocess(img, mean):
- # add mean
img += mean
# shape (1, h, w, d) to (h, w, d)
img = img[0]
img = np.clip(img, 0, 255).astype('uint8')
- # RGB to BGR
+ # rgb to bgr
img = img[...,::-1]
return img
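For reference, preprocess and postprocess are intended to be inverses: OpenCV hands images over as BGR (h, w, d) arrays, while the network consumes RGB, mean-centred (1, h, w, d) batches. A quick illustrative round-trip using the two functions above, assuming a vgg19_mean of shape (1, 1, 1, 3) as defined elsewhere in the script:

import numpy as np

vgg19_mean = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))  # assumed value, RGB order

bgr = np.random.randint(0, 256, size=(224, 224, 3)).astype('float')      # stand-in for cv2.imread output
batch = preprocess(bgr.copy(), vgg19_mean)       # (1, 224, 224, 3): RGB, mean-subtracted
restored = postprocess(batch.copy(), vgg19_mean)
assert restored.shape == (224, 224, 3) and restored.dtype == np.uint8    # back to a BGR uint8 image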
@@ -561,7 +559,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
L_content = sum_content_losses(sess, net, content_img)
# denoising loss
- L_tv = sum_total_variation_losses(init_img)
+ L_tv = sum_total_variation_losses(sess, net, init_img)
# loss weights
alpha = args.content_weight
@@ -575,7 +573,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
if args.video and frame > 1:
gamma = args.temporal_weight
- L_temporal = sum_shortterm_temporal_losses(sess, frame, init_img)
+ L_temporal = sum_shortterm_temporal_losses(sess, net, frame, init_img)
L_total += gamma * L_temporal
# optimization algorithm
@@ -589,7 +587,7 @@ def stylize(content_img, style_imgs, init_img, frame=None):
output_img = sess.run(net['input'])
if args.original_colors:
- output_img = convert_to_original_colors(np.copy(content_img), np.copy(output_img))
+ output_img = convert_to_original_colors(np.copy(content_img), output_img)
if args.video:
write_video_output(frame, output_img)
@@ -597,14 +595,14 @@ def stylize(content_img, style_imgs, init_img, frame=None):
write_image_output(output_img, content_img, style_imgs, init_img)
def minimize_with_lbfgs(sess, net, optimizer, init_img):
- if args.verbose: print('MINIMIZING LOSS USING: L-BFGS OPTIMIZER')
+ if args.verbose: print('\nMINIMIZING LOSS USING: L-BFGS OPTIMIZER')
init_op = tf.initialize_all_variables()
sess.run(init_op)
sess.run(net['input'].assign(init_img))
optimizer.minimize(sess)
def minimize_with_adam(sess, net, optimizer, init_img, loss):
- if args.verbose: print('MINIMIZING LOSS USING: ADAM OPTIMIZER')
+ if args.verbose: print('\nMINIMIZING LOSS USING: ADAM OPTIMIZER')
train_op = optimizer.minimize(loss)
init_op = tf.initialize_all_variables()
sess.run(init_op)
@@ -612,6 +610,9 @@ def minimize_with_adam(sess, net, optimizer, init_img, loss):
iterations = 0
while (iterations < args.max_iterations):
sess.run(train_op)
+ if iterations % args.print_iterations == 0 and args.verbose:
+ curr_loss = loss.eval()
+ print("At iterate {}\tf= {:.5E}".format(iterations, curr_loss))
iterations += 1
def get_optimizer(loss):
@@ -650,16 +651,16 @@ def write_image_output(output_img, content_img, style_imgs, init_img):
# save the configuration settings
out_file = os.path.join(out_dir, 'meta_data.txt')
f = open(out_file, 'w')
- f.write('image name: {}\n'.format(args.img_name))
+ f.write('image_name: {}\n'.format(args.img_name))
f.write('content: {}\n'.format(args.content_img))
index = 0
for style_img, weight in zip(args.style_imgs, args.style_imgs_weights):
- f.write('styles ['+str(index)+']: {} * {}\n'.format(weight, style_img))
+ f.write('styles['+str(index)+']: {} * {}\n'.format(weight, style_img))
index += 1
index = 0
if args.style_mask_imgs is not None:
for mask in args.style_mask_imgs:
- f.write('style masks ['+str(index)+']: {}\n'.format(mask))
+ f.write('style_masks['+str(index)+']: {}\n'.format(mask))
index += 1
f.write('init_type: {}\n'.format(args.init_img_type))
f.write('content_weight: {}\n'.format(args.content_weight))
@@ -698,8 +699,8 @@ def get_content_frame(frame):
return img
def get_content_image(content_img):
- # BGR image
path = os.path.join(args.content_img_dir, content_img)
+ # bgr image
img = cv2.imread(path, cv2.IMREAD_COLOR).astype('float')
h, w, d = img.shape
mx = args.max_size
@@ -713,14 +714,14 @@ def get_content_image(content_img):
img = preprocess(img, vgg19_mean)
return img
-def get_style_images(content_img, scale):
+def get_style_images(content_img):
+ _, ch, cw, cd = content_img.shape
style_imgs = []
for style_fn in args.style_imgs:
path = os.path.join(args.style_imgs_dir, style_fn)
- # BGR image
+ # bgr image
img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32)
- _, h, w, d = content_img.shape
- img = cv2.resize(img, dsize=(int(w*scale), int(h*scale)))
+ img = cv2.resize(img, dsize=(cw, ch))
img = preprocess(img, vgg19_mean)
style_imgs.append(img)
return style_imgs
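For reference, with the style_scale option removed, each style image is now resized to exactly the content image's spatial dimensions; note that cv2.resize expects dsize as (width, height), which is why the call passes (cw, ch). An illustrative sketch of the new behaviour (file path and sizes are made up):

import cv2
import numpy as np

content = np.zeros((1, 480, 640, 3), dtype=np.float32)   # (1, h, w, d) preprocessed content image
_, ch, cw, cd = content.shape

style = cv2.imread('styles/example.jpg', cv2.IMREAD_COLOR).astype(np.float32)  # illustrative path
style = cv2.resize(style, dsize=(cw, ch))                 # dsize is (width, height)
assert style.shape[:2] == (ch, cw)                        # rows = height, cols = width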
@@ -796,7 +797,7 @@ def convert_to_original_colors(content_img, stylized_img):
def render_single_image():
content_img = get_content_image(args.content_img)
- style_imgs = get_style_images(content_img, args.style_scale)
+ style_imgs = get_style_images(content_img)
with tf.Graph().as_default():
print('\n---- RENDERING SINGLE IMAGE ----\n')
init_img = get_init_image(args.init_img_type, content_img, style_imgs)
@@ -811,7 +812,7 @@ def render_video():
print('\n---- RENDERING VIDEO FRAME: {}/{} ----\n'.format(frame, args.end_frame))
if frame == 1:
content_frame = get_content_frame(frame)
- style_imgs = get_style_images(content_frame, args.style_scale)
+ style_imgs = get_style_images(content_frame)
init_img = get_init_image(args.first_frame_type, content_frame, style_imgs, frame)
args.max_iterations = args.first_frame_iterations
tick = time.time()
@@ -820,7 +821,7 @@ def render_video():
print('Frame {} elapsed time: {}'.format(frame, tock - tick))
else:
content_frame = get_content_frame(frame)
- style_imgs = get_style_images(content_frame, args.style_scale)
+ style_imgs = get_style_images(content_frame)
init_img = get_init_image(args.init_frame_type, content_frame, style_imgs, frame)
args.max_iterations = args.frame_iterations
tick = time.time()