diff options
| -rw-r--r-- | neural_style.py | 38 |
1 files changed, 18 insertions, 20 deletions
diff --git a/neural_style.py b/neural_style.py index 13ba513..d5c5eaa 100644 --- a/neural_style.py +++ b/neural_style.py @@ -158,11 +158,11 @@ def parse_args(): help='Boolean flag indicating if the user is generating a video.') parser.add_argument('--start_frame', type=int, - default=1, + default=1, help='First frame number.') parser.add_argument('--end_frame', type=int, - default=1, + default=1, help='Last frame number.') parser.add_argument('--first_frame_type', type=str, @@ -230,9 +230,8 @@ def parse_args(): pre-trained vgg19 convolutional neural network remark: layers are manually initialized for clarity. ''' -vgg19_mean = np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3)) -def build_vgg19(input_img): +def build_model(input_img): if args.verbose: print('\nBUILDING VGG-19 NETWORK') net = {} _, h, w, d = input_img.shape @@ -470,7 +469,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img): ''' denoising loss function - remark: not sure this does anything significant. ''' def sum_total_variation_losses(sess, net, input_img): b, h, w, d = input_img.shape @@ -493,23 +491,23 @@ def read_image(path): img = cv2.imread(path, cv2.IMREAD_COLOR) check_image(img, path) img = img.astype(np.float32) - img = preprocess(img, vgg19_mean) + img = preprocess(img) return img def write_image(path, img): - img = postprocess(img, vgg19_mean) + img = postprocess(img) cv2.imwrite(path, img) -def preprocess(img, mean): +def preprocess(img): # bgr to rgb img = img[...,::-1] # shape (h, w, d) to (1, h, w, d) img = img[np.newaxis,:,:,:] - img -= mean + img -= np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3)) return img -def postprocess(img, mean): - img += mean +def postprocess(img): + img += np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3)) # shape (1, h, w, d) to (h, w, d) img = img[0] img = np.clip(img, 0, 255).astype('uint8') @@ -527,8 +525,8 @@ def read_flow_file(path): flow = np.ndarray((2, h, w), dtype=np.float32) for y in range(h): for x in range(w): - flow[1,y,x] = struct.unpack('f', f.read(4))[0] flow[0,y,x] = struct.unpack('f', f.read(4))[0] + flow[1,y,x] = struct.unpack('f', f.read(4))[0] return flow def read_weights_file(path): @@ -565,7 +563,7 @@ def check_image(img, path): def stylize(content_img, style_imgs, init_img, frame=None): with tf.device(args.device), tf.Session() as sess: # setup network - net = build_vgg19(content_img) + net = build_model(content_img) # style loss if args.style_mask: @@ -731,7 +729,7 @@ def get_content_image(content_img): if w > mx: h = (float(mx) / float(w)) * h img = cv2.resize(img, dsize=(mx, int(h)), interpolation=cv2.INTER_AREA) - img = preprocess(img, vgg19_mean) + img = preprocess(img) return img def get_style_images(content_img): @@ -744,7 +742,7 @@ def get_style_images(content_img): check_image(img, path) img = img.astype(np.float32) img = cv2.resize(img, dsize=(cw, ch), interpolation=cv2.INTER_AREA) - img = preprocess(img, vgg19_mean) + img = preprocess(img) style_imgs.append(img) return style_imgs @@ -781,7 +779,7 @@ def get_prev_warped_frame(frame): path = os.path.join(args.video_input_dir, fn) flow = read_flow_file(path) warped_img = warp_image(prev_img, flow).astype(np.float32) - img = preprocess(warped_img, vgg19_mean) + img = preprocess(warped_img) return img def get_content_weights(frame, prev_frame): @@ -807,8 +805,8 @@ def warp_image(src, flow): return dst def convert_to_original_colors(content_img, stylized_img): - content_img = postprocess(content_img, vgg19_mean) - stylized_img = postprocess(stylized_img, vgg19_mean) + content_img = postprocess(content_img) + stylized_img = postprocess(stylized_img) if args.color_convert_type == 'yuv': cvt_type = cv2.COLOR_BGR2YUV inv_cvt_type = cv2.COLOR_YUV2BGR @@ -827,7 +825,7 @@ def convert_to_original_colors(content_img, stylized_img): _, c2, c3 = cv2.split(content_cvt) merged = cv2.merge((c1, c2, c3)) dst = cv2.cvtColor(merged, inv_cvt_type).astype(np.float32) - dst = preprocess(dst, vgg19_mean) + dst = preprocess(dst) return dst def render_single_image(): @@ -871,4 +869,4 @@ def main(): else: render_single_image() if __name__ == '__main__': - main() +main()
\ No newline at end of file |
