diff options
Diffstat (limited to 'neural_style.py')
| -rw-r--r-- | neural_style.py | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/neural_style.py b/neural_style.py index bad10d1..7e6e20f 100644 --- a/neural_style.py +++ b/neural_style.py @@ -81,7 +81,8 @@ def parse_args(): default=['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'], help='VGG19 layers used for the style image. (default: %(default)s)') - parser.add_argument('--content_layer_weights', type=float, + parser.add_argument('--content_layer_weights', nargs='+', + type=float, default=[1.0], help='Contributions (weights) of each content layer to loss. (default: %(default)s)') @@ -97,6 +98,11 @@ def parse_args(): choices=['yuv', 'ycrcb', 'luv', 'lab'], help='Color space for conversion to original colors (default: %(default)s)') + parser.add_argument('--color_convert_time', type=str, + default='after', + choices=['after', 'before'], + help='Time (before or after) to convert to original colors (default: %(default)s)') + parser.add_argument('--style_mask', action='store_true', help='Transfer the style to masked regions.') @@ -529,7 +535,7 @@ def read_flow_file(path): return flow def read_weights_file(path): - lines = open(path).read().splitlines() + lines = open(path).readlines() header = map(int, lines[0].split(' ')) w = header[0] h = header[1] @@ -543,7 +549,10 @@ def read_weights_file(path): return weights def normalize(weights): - return [float(i) / sum(weights) for i in weights] + denom = sum(weights) + if denom > 0.: + return [float(i) / denom for i in weights] + else: return [0.] def maybe_make_directory(dir_path): if not os.path.exists(dir_path): @@ -639,9 +648,9 @@ def get_optimizer(loss): return optimizer def write_video_output(frame, output_img): - output_frame_fn = args.content_frame_frmt.format(str(frame).zfill(4)) - output_frame_path = os.path.join(args.video_output_dir, output_frame_fn) - write_image(output_frame_path, output_img) + fn = args.content_frame_frmt.format(str(frame).zfill(4)) + path = os.path.join(args.video_output_dir, fn) + write_image(path, output_img) def write_image_output(output_img, content_img, style_imgs, init_img): out_dir = os.path.join(args.img_output_dir, args.img_name) @@ -687,11 +696,11 @@ def write_image_output(output_img, content_img, style_imgs, init_img): ''' image loading and processing ''' -def get_init_image(init_type, content_img, style_img, frame=None): +def get_init_image(init_type, content_img, style_imgs, frame=None): if init_type == 'content': return content_img elif init_type == 'style': - return style_img + return style_imgs[0] elif init_type == 'random': init_img = get_noise_image(args.noise_ratio, content_img) return init_img @@ -704,9 +713,9 @@ def get_init_image(init_type, content_img, style_img, frame=None): return init_img def get_content_frame(frame): - content_fn = args.content_frame_frmt.format(str(frame).zfill(4)) - content_path = os.path.join(args.video_input_dir, content_fn) - img = read_image(content_path) + fn = args.content_frame_frmt.format(str(frame).zfill(4)) + path = os.path.join(args.video_input_dir, fn) + img = read_image(path) return img def get_content_image(content_img): @@ -751,7 +760,8 @@ def get_mask_image(mask_img, width, height): path = os.path.join(args.content_img_dir, mask_img) img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) check_image(img, path) - img = cv2.resize(img, dsize=(width, height), interpolation=cv2.INTER_AREA).astype(np.float32) + img = cv2.resize(img, dsize=(width, height), interpolation=cv2.INTER_AREA) + img = img.astype(np.float32) mx = np.amax(img) img /= mx return img @@ -769,9 +779,9 @@ def get_prev_warped_frame(frame): prev_img = get_prev_frame(frame) prev_frame = frame - 1 # backwards flow: current frame -> previous frame - flow_fn = args.backward_optical_flow_frmt.format(str(frame), str(prev_frame)) - flow_path = os.path.join(args.video_input_dir, flow_fn) - flow = read_flow_file(flow_path) + fn = args.backward_optical_flow_frmt.format(str(frame), str(prev_frame)) + path = os.path.join(args.video_input_dir, fn) + flow = read_flow_file(path) warped_img = warp_image(prev_img, flow).astype(np.float32) img = preprocess(warped_img, vgg19_mean) return img |
