diff options
| author | Cameron <cysmith1010@gmail.com> | 2017-04-09 16:44:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-04-09 16:44:08 -0700 |
| commit | 2901b7aca6ae43847d6b329609f1ee55a1ce6f3f (patch) | |
| tree | 93b69002f820a2ec1100ba438cad6ec2de174693 | |
| parent | 18b9314f8c45edcb8a78e33ea8b4acbdeee63908 (diff) | |
Simplified pre/post-processing
| -rw-r--r-- | neural_style.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/neural_style.py b/neural_style.py index 13ba513..b717216 100644 --- a/neural_style.py +++ b/neural_style.py @@ -230,9 +230,8 @@ def parse_args(): pre-trained vgg19 convolutional neural network remark: layers are manually initialized for clarity. ''' -vgg19_mean = np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3)) -def build_vgg19(input_img): +def build_model(input_img): if args.verbose: print('\nBUILDING VGG-19 NETWORK') net = {} _, h, w, d = input_img.shape @@ -470,7 +469,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img): ''' denoising loss function - remark: not sure this does anything significant. ''' def sum_total_variation_losses(sess, net, input_img): b, h, w, d = input_img.shape @@ -493,23 +491,23 @@ def read_image(path): img = cv2.imread(path, cv2.IMREAD_COLOR) check_image(img, path) img = img.astype(np.float32) - img = preprocess(img, vgg19_mean) + img = preprocess(img) return img def write_image(path, img): - img = postprocess(img, vgg19_mean) + img = postprocess(img) cv2.imwrite(path, img) -def preprocess(img, mean): +def preprocess(img): # bgr to rgb img = img[...,::-1] # shape (h, w, d) to (1, h, w, d) img = img[np.newaxis,:,:,:] - img -= mean + img -= np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3)) return img -def postprocess(img, mean): - img += mean +def postprocess(img): + img += np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3)) # shape (1, h, w, d) to (h, w, d) img = img[0] img = np.clip(img, 0, 255).astype('uint8') @@ -565,7 +563,7 @@ def check_image(img, path): def stylize(content_img, style_imgs, init_img, frame=None): with tf.device(args.device), tf.Session() as sess: # setup network - net = build_vgg19(content_img) + net = build_model(content_img) # style loss if args.style_mask: @@ -731,7 +729,7 @@ def get_content_image(content_img): if w > mx: h = (float(mx) / float(w)) * h img = cv2.resize(img, dsize=(mx, int(h)), interpolation=cv2.INTER_AREA) - img = preprocess(img, vgg19_mean) + img = preprocess(img) return img def get_style_images(content_img): @@ -744,7 +742,7 @@ def get_style_images(content_img): check_image(img, path) img = img.astype(np.float32) img = cv2.resize(img, dsize=(cw, ch), interpolation=cv2.INTER_AREA) - img = preprocess(img, vgg19_mean) + img = preprocess(img) style_imgs.append(img) return style_imgs @@ -781,7 +779,7 @@ def get_prev_warped_frame(frame): path = os.path.join(args.video_input_dir, fn) flow = read_flow_file(path) warped_img = warp_image(prev_img, flow).astype(np.float32) - img = preprocess(warped_img, vgg19_mean) + img = preprocess(warped_img) return img def get_content_weights(frame, prev_frame): @@ -807,8 +805,8 @@ def warp_image(src, flow): return dst def convert_to_original_colors(content_img, stylized_img): - content_img = postprocess(content_img, vgg19_mean) - stylized_img = postprocess(stylized_img, vgg19_mean) + content_img = postprocess(content_img) + stylized_img = postprocess(stylized_img) if args.color_convert_type == 'yuv': cvt_type = cv2.COLOR_BGR2YUV inv_cvt_type = cv2.COLOR_YUV2BGR @@ -827,7 +825,7 @@ def convert_to_original_colors(content_img, stylized_img): _, c2, c3 = cv2.split(content_cvt) merged = cv2.merge((c1, c2, c3)) dst = cv2.cvtColor(merged, inv_cvt_type).astype(np.float32) - dst = preprocess(dst, vgg19_mean) + dst = preprocess(dst) return dst def render_single_image(): |
