summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCameron <cysmith1010@gmail.com>2017-04-09 16:44:08 -0700
committerGitHub <noreply@github.com>2017-04-09 16:44:08 -0700
commit2901b7aca6ae43847d6b329609f1ee55a1ce6f3f (patch)
tree93b69002f820a2ec1100ba438cad6ec2de174693
parent18b9314f8c45edcb8a78e33ea8b4acbdeee63908 (diff)
Simplified pre/post-processing
-rw-r--r--neural_style.py30
1 file changed, 14 insertions, 16 deletions
diff --git a/neural_style.py b/neural_style.py
index 13ba513..b717216 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -230,9 +230,8 @@ def parse_args():
pre-trained vgg19 convolutional neural network
remark: layers are manually initialized for clarity.
'''
-vgg19_mean = np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3))
-def build_vgg19(input_img):
+def build_model(input_img):
if args.verbose: print('\nBUILDING VGG-19 NETWORK')
net = {}
_, h, w, d = input_img.shape
@@ -470,7 +469,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img):
'''
denoising loss function
- remark: not sure this does anything significant.
'''
def sum_total_variation_losses(sess, net, input_img):
b, h, w, d = input_img.shape
@@ -493,23 +491,23 @@ def read_image(path):
img = cv2.imread(path, cv2.IMREAD_COLOR)
check_image(img, path)
img = img.astype(np.float32)
- img = preprocess(img, vgg19_mean)
+ img = preprocess(img)
return img
def write_image(path, img):
- img = postprocess(img, vgg19_mean)
+ img = postprocess(img)
cv2.imwrite(path, img)
-def preprocess(img, mean):
+def preprocess(img):
# bgr to rgb
img = img[...,::-1]
# shape (h, w, d) to (1, h, w, d)
img = img[np.newaxis,:,:,:]
- img -= mean
+ img -= np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3))
return img
-def postprocess(img, mean):
- img += mean
+def postprocess(img):
+ img += np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3))
# shape (1, h, w, d) to (h, w, d)
img = img[0]
img = np.clip(img, 0, 255).astype('uint8')
@@ -565,7 +563,7 @@ def check_image(img, path):
def stylize(content_img, style_imgs, init_img, frame=None):
with tf.device(args.device), tf.Session() as sess:
# setup network
- net = build_vgg19(content_img)
+ net = build_model(content_img)
# style loss
if args.style_mask:
@@ -731,7 +729,7 @@ def get_content_image(content_img):
if w > mx:
h = (float(mx) / float(w)) * h
img = cv2.resize(img, dsize=(mx, int(h)), interpolation=cv2.INTER_AREA)
- img = preprocess(img, vgg19_mean)
+ img = preprocess(img)
return img
def get_style_images(content_img):
@@ -744,7 +742,7 @@ def get_style_images(content_img):
check_image(img, path)
img = img.astype(np.float32)
img = cv2.resize(img, dsize=(cw, ch), interpolation=cv2.INTER_AREA)
- img = preprocess(img, vgg19_mean)
+ img = preprocess(img)
style_imgs.append(img)
return style_imgs
@@ -781,7 +779,7 @@ def get_prev_warped_frame(frame):
path = os.path.join(args.video_input_dir, fn)
flow = read_flow_file(path)
warped_img = warp_image(prev_img, flow).astype(np.float32)
- img = preprocess(warped_img, vgg19_mean)
+ img = preprocess(warped_img)
return img
def get_content_weights(frame, prev_frame):
@@ -807,8 +805,8 @@ def warp_image(src, flow):
return dst
def convert_to_original_colors(content_img, stylized_img):
- content_img = postprocess(content_img, vgg19_mean)
- stylized_img = postprocess(stylized_img, vgg19_mean)
+ content_img = postprocess(content_img)
+ stylized_img = postprocess(stylized_img)
if args.color_convert_type == 'yuv':
cvt_type = cv2.COLOR_BGR2YUV
inv_cvt_type = cv2.COLOR_YUV2BGR
@@ -827,7 +825,7 @@ def convert_to_original_colors(content_img, stylized_img):
_, c2, c3 = cv2.split(content_cvt)
merged = cv2.merge((c1, c2, c3))
dst = cv2.cvtColor(merged, inv_cvt_type).astype(np.float32)
- dst = preprocess(dst, vgg19_mean)
+ dst = preprocess(dst)
return dst
def render_single_image():