summaryrefslogtreecommitdiff
path: root/Code/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'Code/utils.py')
-rw-r--r--Code/utils.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/Code/utils.py b/Code/utils.py
index c2bf856..042912e 100644
--- a/Code/utils.py
+++ b/Code/utils.py
@@ -1,6 +1,6 @@
import tensorflow as tf
import numpy as np
-from scipy.ndimage import imread
+from scipy.misc import imread
from glob import glob
import os
@@ -69,14 +69,14 @@ def get_full_clips(data_dir, num_clips, num_rec_out=1):
A batch of frame sequences with values normalized in range [-1, 1].
"""
clips = np.empty([num_clips,
- c.TEST_HEIGHT,
- c.TEST_WIDTH,
+ c.FULL_HEIGHT,
+ c.FULL_WIDTH,
(3 * (c.HIST_LEN + num_rec_out))])
# get num_clips random episodes
ep_dirs = np.random.choice(glob(data_dir + '*'), num_clips)
- # get a random clip of length HIST_LEN + 1 from each episode
+ # get a random clip of length HIST_LEN + num_rec_out from each episode
for clip_num, ep_dir in enumerate(ep_dirs):
ep_frame_paths = glob(os.path.join(ep_dir, '*'))
start_index = np.random.choice(len(ep_frame_paths) - (c.HIST_LEN + num_rec_out - 1))
@@ -105,8 +105,8 @@ def process_clip():
take_first = np.random.choice(2, p=[0.95, 0.05])
cropped_clip = np.empty([c.TRAIN_HEIGHT, c.TRAIN_WIDTH, 3 * (c.HIST_LEN + 1)])
for i in xrange(100): # cap at 100 trials in case the clip has no movement anywhere
- crop_x = np.random.choice(c.TEST_WIDTH - c.TRAIN_WIDTH + 1)
- crop_y = np.random.choice(c.TEST_HEIGHT - c.TRAIN_HEIGHT + 1)
+ crop_x = np.random.choice(c.FULL_WIDTH - c.TRAIN_WIDTH + 1)
+ crop_y = np.random.choice(c.FULL_HEIGHT - c.TRAIN_HEIGHT + 1)
cropped_clip = clip[crop_y:crop_y + c.TRAIN_HEIGHT, crop_x:crop_x + c.TRAIN_WIDTH, :]
if take_first or clip_l2_diff(cropped_clip) > c.MOVEMENT_THRESHOLD: