summary | refs | log | tree | commit | diff
path: root/neural_style.py
diff options
context:
space:
mode:
Diffstat (limited to 'neural_style.py')
-rw-r--r-- neural_style.py 40
1 files changed, 25 insertions, 15 deletions
diff --git a/neural_style.py b/neural_style.py
index bad10d1..7e6e20f 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -81,7 +81,8 @@ def parse_args():
default=['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1'],
help='VGG19 layers used for the style image. (default: %(default)s)')
- parser.add_argument('--content_layer_weights', type=float,
+ parser.add_argument('--content_layer_weights', nargs='+',
+ type=float,
default=[1.0],
help='Contributions (weights) of each content layer to loss. (default: %(default)s)')
@@ -97,6 +98,11 @@ def parse_args():
choices=['yuv', 'ycrcb', 'luv', 'lab'],
help='Color space for conversion to original colors (default: %(default)s)')
+ parser.add_argument('--color_convert_time', type=str,
+ default='after',
+ choices=['after', 'before'],
+ help='Time (before or after) to convert to original colors (default: %(default)s)')
+
parser.add_argument('--style_mask', action='store_true',
help='Transfer the style to masked regions.')
@@ -529,7 +535,7 @@ def read_flow_file(path):
return flow
def read_weights_file(path):
- lines = open(path).read().splitlines()
+ lines = open(path).readlines()
header = map(int, lines[0].split(' '))
w = header[0]
h = header[1]
@@ -543,7 +549,10 @@ def read_weights_file(path):
return weights
def normalize(weights):
- return [float(i) / sum(weights) for i in weights]
+ denom = sum(weights)
+ if denom > 0.:
+ return [float(i) / denom for i in weights]
+ else: return [0.]
def maybe_make_directory(dir_path):
if not os.path.exists(dir_path):
@@ -639,9 +648,9 @@ def get_optimizer(loss):
return optimizer
def write_video_output(frame, output_img):
- output_frame_fn = args.content_frame_frmt.format(str(frame).zfill(4))
- output_frame_path = os.path.join(args.video_output_dir, output_frame_fn)
- write_image(output_frame_path, output_img)
+ fn = args.content_frame_frmt.format(str(frame).zfill(4))
+ path = os.path.join(args.video_output_dir, fn)
+ write_image(path, output_img)
def write_image_output(output_img, content_img, style_imgs, init_img):
out_dir = os.path.join(args.img_output_dir, args.img_name)
@@ -687,11 +696,11 @@ def write_image_output(output_img, content_img, style_imgs, init_img):
'''
image loading and processing
'''
-def get_init_image(init_type, content_img, style_img, frame=None):
+def get_init_image(init_type, content_img, style_imgs, frame=None):
if init_type == 'content':
return content_img
elif init_type == 'style':
- return style_img
+ return style_imgs[0]
elif init_type == 'random':
init_img = get_noise_image(args.noise_ratio, content_img)
return init_img
@@ -704,9 +713,9 @@ def get_init_image(init_type, content_img, style_img, frame=None):
return init_img
def get_content_frame(frame):
- content_fn = args.content_frame_frmt.format(str(frame).zfill(4))
- content_path = os.path.join(args.video_input_dir, content_fn)
- img = read_image(content_path)
+ fn = args.content_frame_frmt.format(str(frame).zfill(4))
+ path = os.path.join(args.video_input_dir, fn)
+ img = read_image(path)
return img
def get_content_image(content_img):
@@ -751,7 +760,8 @@ def get_mask_image(mask_img, width, height):
path = os.path.join(args.content_img_dir, mask_img)
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
check_image(img, path)
- img = cv2.resize(img, dsize=(width, height), interpolation=cv2.INTER_AREA).astype(np.float32)
+ img = cv2.resize(img, dsize=(width, height), interpolation=cv2.INTER_AREA)
+ img = img.astype(np.float32)
mx = np.amax(img)
img /= mx
return img
@@ -769,9 +779,9 @@ def get_prev_warped_frame(frame):
prev_img = get_prev_frame(frame)
prev_frame = frame - 1
# backwards flow: current frame -> previous frame
- flow_fn = args.backward_optical_flow_frmt.format(str(frame), str(prev_frame))
- flow_path = os.path.join(args.video_input_dir, flow_fn)
- flow = read_flow_file(flow_path)
+ fn = args.backward_optical_flow_frmt.format(str(frame), str(prev_frame))
+ path = os.path.join(args.video_input_dir, fn)
+ flow = read_flow_file(path)
warped_img = warp_image(prev_img, flow).astype(np.float32)
img = preprocess(warped_img, vgg19_mean)
return img