summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Code/avg_runner.py6
-rw-r--r--Code/constants.py16
-rw-r--r--Code/process_data.py5
-rw-r--r--Code/utils.py12
4 files changed, 26 insertions, 13 deletions
diff --git a/Code/avg_runner.py b/Code/avg_runner.py
index 18bc592..9a27426 100644
--- a/Code/avg_runner.py
+++ b/Code/avg_runner.py
@@ -42,8 +42,8 @@ class AVGRunner:
self.summary_writer,
c.TRAIN_HEIGHT,
c.TRAIN_WIDTH,
- c.TEST_HEIGHT,
- c.TEST_WIDTH,
+ c.FULL_HEIGHT,
+ c.FULL_WIDTH,
c.SCALE_FMS_G,
c.SCALE_KERNEL_SIZES_G)
@@ -161,7 +161,7 @@ def main():
# set test frame dimensions
assert os.path.exists(c.TEST_DIR)
- c.TEST_HEIGHT, c.TEST_WIDTH = c.get_test_frame_dims()
+ c.FULL_HEIGHT, c.FULL_WIDTH = c.get_test_frame_dims()
##
# Init and run the predictor
diff --git a/Code/constants.py b/Code/constants.py
index dfd3660..761448b 100644
--- a/Code/constants.py
+++ b/Code/constants.py
@@ -49,16 +49,23 @@ def get_test_frame_dims():
return shape[0], shape[1]
+def get_train_frame_dims():
+ img_path = glob(os.path.join(TRAIN_DIR, '*/*'))[0]
+ img = imread(img_path, mode='RGB')
+ shape = np.shape(img)
+
+ return shape[0], shape[1]
+
def set_test_dir(directory):
"""
Edits all constants dependent on TEST_DIR.
@param directory: The new test directory.
"""
- global TEST_DIR, TEST_HEIGHT, TEST_WIDTH
+ global TEST_DIR, FULL_HEIGHT, FULL_WIDTH
TEST_DIR = directory
- TEST_HEIGHT, TEST_WIDTH = get_test_frame_dims()
+ FULL_HEIGHT, FULL_WIDTH = get_test_frame_dims()
# root directory for all data
DATA_DIR = get_dir('../Data/')
@@ -75,8 +82,9 @@ MOVEMENT_THRESHOLD = 100
# total number of processed clips in TRAIN_DIR_CLIPS
NUM_CLIPS = len(glob(TRAIN_DIR_CLIPS + '*'))
-# the height and width of the full frames to test on. Set in avg_runner.py main.
-TEST_HEIGHT = TEST_WIDTH = 0
+# the height and width of the full frames to test on. Set in avg_runner.py or process_data.py main.
+FULL_HEIGHT = 210
+FULL_WIDTH = 160
# the height and width of the patches to train on
TRAIN_HEIGHT = TRAIN_WIDTH = 32
diff --git a/Code/process_data.py b/Code/process_data.py
index 446de91..8375035 100644
--- a/Code/process_data.py
+++ b/Code/process_data.py
@@ -2,6 +2,7 @@ import numpy as np
import getopt
import sys
from glob import glob
+import os
import constants as c
from utils import process_clip
@@ -64,6 +65,10 @@ def main():
usage()
sys.exit(2)
+ # set train frame dimensions
+ assert os.path.exists(c.TRAIN_DIR)
+ c.FULL_HEIGHT, c.FULL_WIDTH = c.get_train_frame_dims()
+
##
# Process data for training
##
diff --git a/Code/utils.py b/Code/utils.py
index c2bf856..042912e 100644
--- a/Code/utils.py
+++ b/Code/utils.py
@@ -1,6 +1,6 @@
import tensorflow as tf
import numpy as np
-from scipy.ndimage import imread
+from scipy.misc import imread
from glob import glob
import os
@@ -69,14 +69,14 @@ def get_full_clips(data_dir, num_clips, num_rec_out=1):
A batch of frame sequences with values normalized in range [-1, 1].
"""
clips = np.empty([num_clips,
- c.TEST_HEIGHT,
- c.TEST_WIDTH,
+ c.FULL_HEIGHT,
+ c.FULL_WIDTH,
(3 * (c.HIST_LEN + num_rec_out))])
# get num_clips random episodes
ep_dirs = np.random.choice(glob(data_dir + '*'), num_clips)
- # get a random clip of length HIST_LEN + 1 from each episode
+ # get a random clip of length HIST_LEN + num_rec_out from each episode
for clip_num, ep_dir in enumerate(ep_dirs):
ep_frame_paths = glob(os.path.join(ep_dir, '*'))
start_index = np.random.choice(len(ep_frame_paths) - (c.HIST_LEN + num_rec_out - 1))
@@ -105,8 +105,8 @@ def process_clip():
take_first = np.random.choice(2, p=[0.95, 0.05])
cropped_clip = np.empty([c.TRAIN_HEIGHT, c.TRAIN_WIDTH, 3 * (c.HIST_LEN + 1)])
for i in xrange(100): # cap at 100 trials in case the clip has no movement anywhere
- crop_x = np.random.choice(c.TEST_WIDTH - c.TRAIN_WIDTH + 1)
- crop_y = np.random.choice(c.TEST_HEIGHT - c.TRAIN_HEIGHT + 1)
+ crop_x = np.random.choice(c.FULL_WIDTH - c.TRAIN_WIDTH + 1)
+ crop_y = np.random.choice(c.FULL_HEIGHT - c.TRAIN_HEIGHT + 1)
cropped_clip = clip[crop_y:crop_y + c.TRAIN_HEIGHT, crop_x:crop_x + c.TRAIN_WIDTH, :]
if take_first or clip_l2_diff(cropped_clip) > c.MOVEMENT_THRESHOLD: