summaryrefslogtreecommitdiff
path: root/Code/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'Code/test.py')
-rw-r--r--Code/test.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/Code/test.py b/Code/test.py
new file mode 100644
index 0000000..3b2cd86
--- /dev/null
+++ b/Code/test.py
@@ -0,0 +1,67 @@
+import numpy as np
+import constants as c
+from utils import normalize_frames, get_test_batch
+from glob import glob
+from scipy.misc import imread, imsave
+import os
+
+def save_batch(batch, num_rec_out):
+ # TEST
+ for clip_num, clip in enumerate(batch):
+ for frame_num in xrange(c.HIST_LEN + num_rec_out):
+ imsave(c.get_dir('TEST/' + str(clip_num) + '/') + str(frame_num) + '.png',
+ clip[:, :, frame_num * 3:(frame_num + 1) * 3])
+
+def get_full_clips(data_dir, num_clips, num_rec_out=1):
+ """
+ Loads a batch of random clips from the unprocessed train or test data.
+
+ @param data_dir: The directory of the data to read. Should be either c.TRAIN_DIR or c.TEST_DIR.
+ @param num_clips: The number of clips to read.
+ @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively,
+ using the previously-generated frames as input. Default = 1.
+
+ @return: An array of shape
+ [num_clips, c.TRAIN_HEIGHT, c.TRAIN_WIDTH, (3 * (c.HIST_LEN + num_rec_out))].
+ A batch of frame sequences with values normalized in range [-1, 1].
+ """
+ clips = np.empty([num_clips,
+ c.FULL_HEIGHT,
+ c.FULL_WIDTH,
+ (3 * (c.HIST_LEN + num_rec_out))])
+
+ # get num_clips random episodes
+ ep_dirs = np.random.choice(glob(data_dir + '*'), num_clips)
+ print ep_dirs
+
+ # get a random clip of length HIST_LEN + num_rec_out from each episode
+ for clip_num, ep_dir in enumerate(ep_dirs):
+ ep_frame_paths = glob(os.path.join(ep_dir, '*'))
+ start_index = np.random.choice(len(ep_frame_paths) - (c.HIST_LEN + num_rec_out - 1))
+ clip_frame_paths = ep_frame_paths[start_index:start_index + (c.HIST_LEN + num_rec_out)]
+
+ # read in frames
+ for frame_num, frame_path in enumerate(clip_frame_paths):
+ frame = imread(frame_path, mode='RGB')
+ norm_frame = normalize_frames(frame)
+
+ clips[clip_num, :, :, frame_num * 3:(frame_num + 1) * 3] = norm_frame
+
+ # TEST
+ save_batch(clips, num_rec_out)
+
+ return clips
+
+# get_full_clips('../Data/Ms_Pacman/Test/', 8, num_rec_out=64)
+
+def test():
+ """
+ Runs one test step on the generator network.
+ """
+ batch = get_test_batch(c.BATCH_SIZE, num_rec_out=2)
+ save_batch(batch, 2)
+
+ # self.g_model.test_batch(
+ # batch, self.global_step, num_rec_out=2)
+
+test() \ No newline at end of file