summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcam <cameron@ideum.com>2017-07-03 19:00:55 -0700
committercam <cameron@ideum.com>2017-07-03 19:00:55 -0700
commit4d756ad9e52140fa207bcfaf7bfd376cf4a06679 (patch)
treea946ce38b8a95897dc88038690b6f9e3cb007871
parent6840bd72571b7e18b76f901891648fca5ff7352a (diff)
"Video bug"
-rw-r--r--neural_style.py38
1 files changed, 18 insertions, 20 deletions
diff --git a/neural_style.py b/neural_style.py
index 13ba513..d5c5eaa 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -158,11 +158,11 @@ def parse_args():
help='Boolean flag indicating if the user is generating a video.')
parser.add_argument('--start_frame', type=int,
- default=1,
+ default=1,
help='First frame number.')
parser.add_argument('--end_frame', type=int,
- default=1,
+ default=1,
help='Last frame number.')
parser.add_argument('--first_frame_type', type=str,
@@ -230,9 +230,8 @@ def parse_args():
pre-trained vgg19 convolutional neural network
remark: layers are manually initialized for clarity.
'''
-vgg19_mean = np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3))
-def build_vgg19(input_img):
+def build_model(input_img):
if args.verbose: print('\nBUILDING VGG-19 NETWORK')
net = {}
_, h, w, d = input_img.shape
@@ -470,7 +469,6 @@ def sum_shortterm_temporal_losses(sess, net, frame, input_img):
'''
denoising loss function
- remark: not sure this does anything significant.
'''
def sum_total_variation_losses(sess, net, input_img):
b, h, w, d = input_img.shape
@@ -493,23 +491,23 @@ def read_image(path):
img = cv2.imread(path, cv2.IMREAD_COLOR)
check_image(img, path)
img = img.astype(np.float32)
- img = preprocess(img, vgg19_mean)
+ img = preprocess(img)
return img
def write_image(path, img):
- img = postprocess(img, vgg19_mean)
+ img = postprocess(img)
cv2.imwrite(path, img)
-def preprocess(img, mean):
+def preprocess(img):
# bgr to rgb
img = img[...,::-1]
# shape (h, w, d) to (1, h, w, d)
img = img[np.newaxis,:,:,:]
- img -= mean
+ img -= np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3))
return img
-def postprocess(img, mean):
- img += mean
+def postprocess(img):
+ img += np.array([123.68, 116.779, 103.939]).reshape((1,1,1,3))
# shape (1, h, w, d) to (h, w, d)
img = img[0]
img = np.clip(img, 0, 255).astype('uint8')
@@ -527,8 +525,8 @@ def read_flow_file(path):
flow = np.ndarray((2, h, w), dtype=np.float32)
for y in range(h):
for x in range(w):
- flow[1,y,x] = struct.unpack('f', f.read(4))[0]
flow[0,y,x] = struct.unpack('f', f.read(4))[0]
+ flow[1,y,x] = struct.unpack('f', f.read(4))[0]
return flow
def read_weights_file(path):
@@ -565,7 +563,7 @@ def check_image(img, path):
def stylize(content_img, style_imgs, init_img, frame=None):
with tf.device(args.device), tf.Session() as sess:
# setup network
- net = build_vgg19(content_img)
+ net = build_model(content_img)
# style loss
if args.style_mask:
@@ -731,7 +729,7 @@ def get_content_image(content_img):
if w > mx:
h = (float(mx) / float(w)) * h
img = cv2.resize(img, dsize=(mx, int(h)), interpolation=cv2.INTER_AREA)
- img = preprocess(img, vgg19_mean)
+ img = preprocess(img)
return img
def get_style_images(content_img):
@@ -744,7 +742,7 @@ def get_style_images(content_img):
check_image(img, path)
img = img.astype(np.float32)
img = cv2.resize(img, dsize=(cw, ch), interpolation=cv2.INTER_AREA)
- img = preprocess(img, vgg19_mean)
+ img = preprocess(img)
style_imgs.append(img)
return style_imgs
@@ -781,7 +779,7 @@ def get_prev_warped_frame(frame):
path = os.path.join(args.video_input_dir, fn)
flow = read_flow_file(path)
warped_img = warp_image(prev_img, flow).astype(np.float32)
- img = preprocess(warped_img, vgg19_mean)
+ img = preprocess(warped_img)
return img
def get_content_weights(frame, prev_frame):
@@ -807,8 +805,8 @@ def warp_image(src, flow):
return dst
def convert_to_original_colors(content_img, stylized_img):
- content_img = postprocess(content_img, vgg19_mean)
- stylized_img = postprocess(stylized_img, vgg19_mean)
+ content_img = postprocess(content_img)
+ stylized_img = postprocess(stylized_img)
if args.color_convert_type == 'yuv':
cvt_type = cv2.COLOR_BGR2YUV
inv_cvt_type = cv2.COLOR_YUV2BGR
@@ -827,7 +825,7 @@ def convert_to_original_colors(content_img, stylized_img):
_, c2, c3 = cv2.split(content_cvt)
merged = cv2.merge((c1, c2, c3))
dst = cv2.cvtColor(merged, inv_cvt_type).astype(np.float32)
- dst = preprocess(dst, vgg19_mean)
+ dst = preprocess(dst)
return dst
def render_single_image():
@@ -871,4 +869,4 @@ def main():
else: render_single_image()
if __name__ == '__main__':
- main()
+main() \ No newline at end of file