author    jules@lens <julescarbon@gmail.com>  2018-04-24 20:19:24 +0200
committer jules@lens <julescarbon@gmail.com>  2018-04-24 20:19:24 +0200
commit    83e91e0a8effcd20466e56b6ecc3e349bbfa5e0e (patch)
tree      377977c1068fa2411ac6b0e7c6da4aa97873ae8b /Code
parent    9b0d10f357871231bbec06c610363588812216e1 (diff)

updates to tensorflow code and processing additions (HEAD, master)
Diffstat (limited to 'Code')
-rw-r--r--  Code/avg_runner.py      28
-rw-r--r--  Code/constants.py       31
-rw-r--r--  Code/d_model.py         13
-rw-r--r--  Code/g_model.py         35
-rw-r--r--  Code/loss_functions.py  10
-rw-r--r--  Code/tfutils.py         12
-rw-r--r--  Code/utils.py           34
7 files changed, 124 insertions(+), 39 deletions(-)
diff --git a/Code/avg_runner.py b/Code/avg_runner.py
index 6809187..ed72b63 100644
--- a/Code/avg_runner.py
+++ b/Code/avg_runner.py
@@ -3,7 +3,11 @@ import getopt
import sys
import os
-from utils import get_train_batch, get_test_batch
+import pprint
+pp = pprint.PrettyPrinter(indent=2)
+
+from glob import glob
+from utils import get_train_batch, get_test_batch, get_all_clips
import constants as c
from g_model import GeneratorModel
from d_model import DiscriminatorModel
@@ -26,8 +30,11 @@ class AVGRunner:
self.num_steps = num_steps
self.num_test_rec = num_test_rec
- self.sess = tf.Session()
- self.summary_writer = tf.train.SummaryWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+ self.sess = tf.Session(config=config)
+ self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
if c.ADVERSARIAL:
print 'Init discriminator...'
@@ -97,11 +104,20 @@ class AVGRunner:
self.g_model.test_batch(
batch, self.global_step, num_rec_out=self.num_test_rec)
+ def process(self):
+ """
+ Process a directory of images using the generator network.
+ """
+ batch = get_all_clips(c.PROCESS_DIR)
+ self.g_model.test_batch(
+ batch, self.global_step, num_rec_out=self.num_test_rec, process_only=True)
+
def usage():
print 'Options:'
print '-l/--load_path= <Relative/path/to/saved/model>'
print '-t/--test_dir= <Directory of test images>'
+ print '-p/--process_dir= <Directory to process>'
print '-r/--recursions= <# recursive predictions to make on test>'
print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)'
print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>'
@@ -123,6 +139,7 @@ def main():
load_path = None
test_only = False
+ process_only = False
num_test_rec = 1 # number of recursive predictions to make on test
num_steps = 1000001
try:
@@ -140,6 +157,9 @@ def main():
load_path = arg
if opt in ('-t', '--test_dir'):
c.set_test_dir(arg)
+ if opt in ('-p', '--process_dir'):
+ c.set_process_dir(arg)
+ process_only = True
if opt in ('-r', '--recursions'):
num_test_rec = int(arg)
if opt in ('-a', '--adversarial'):
@@ -177,6 +197,8 @@ def main():
runner = AVGRunner(num_steps, load_path, num_test_rec)
if test_only:
runner.test()
+ elif process_only:
+ runner.process()
else:
runner.train()
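
Taken together, the avg_runner.py changes add a process-only mode alongside train and test. A minimal sketch of the intended flow, assuming a trained checkpoint (the paths and recursion count below are hypothetical; the AVGRunner positional arguments mirror the AVGRunner(num_steps, load_path, num_test_rec) call in main()):

# CLI, using the flag names from usage() above:
#   python avg_runner.py -l ../Data/Save/Models/run1/model.ckpt -p ../Data/WoodFlat/Process/ -r 30
import constants as c

c.set_process_dir('../Data/WoodFlat/Process/')  # sets PROCESS_DIR and the full frame dims
runner = AVGRunner(1000001, '../Data/Save/Models/run1/model.ckpt', 30)
runner.process()  # run the generator over PROCESS_DIR; no training happens
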
diff --git a/Code/constants.py b/Code/constants.py
index 761448b..5d5446b 100644
--- a/Code/constants.py
+++ b/Code/constants.py
@@ -42,18 +42,22 @@ def clear_dir(directory):
except Exception as e:
print(e)
+def get_process_frame_dims():
+ img_path = glob(os.path.join(PROCESS_DIR, '*'))[0]
+ img = imread(img_path, mode='RGB')
+ shape = np.shape(img)
+ return shape[0], shape[1]
+
def get_test_frame_dims():
img_path = glob(os.path.join(TEST_DIR, '*/*'))[0]
img = imread(img_path, mode='RGB')
shape = np.shape(img)
-
return shape[0], shape[1]
def get_train_frame_dims():
img_path = glob(os.path.join(TRAIN_DIR, '*/*'))[0]
img = imread(img_path, mode='RGB')
shape = np.shape(img)
-
return shape[0], shape[1]
def set_test_dir(directory):
@@ -67,15 +71,28 @@ def set_test_dir(directory):
TEST_DIR = directory
FULL_HEIGHT, FULL_WIDTH = get_test_frame_dims()
+def set_process_dir(directory):
+ """
+ Edits all constants dependent on PROCESS_DIR.
+
+ @param directory: The new process directory.
+ """
+ global PROCESS_DIR, FULL_HEIGHT, FULL_WIDTH
+
+ PROCESS_DIR = directory
+ FULL_HEIGHT, FULL_WIDTH = get_process_frame_dims()
+
# root directory for all data
DATA_DIR = get_dir('../Data/')
# directory of unprocessed training frames
-TRAIN_DIR = os.path.join(DATA_DIR, 'Ms_Pacman/Train/')
+TRAIN_DIR = os.path.join(DATA_DIR, 'WoodFlat/Train/')
# directory of unprocessed test frames
-TEST_DIR = os.path.join(DATA_DIR, 'Ms_Pacman/Test/')
+TEST_DIR = os.path.join(DATA_DIR, 'WoodFlat/Test/')
+# directory of all the images we want to process
+PROCESS_DIR = os.path.join(DATA_DIR, 'WoodFlat/Process/')
# Directory of processed training clips.
# hidden so finder doesn't freeze w/ so many files. DON'T USE `ls` COMMAND ON THIS DIR!
-TRAIN_DIR_CLIPS = get_dir(os.path.join(DATA_DIR, '.Clips/'))
+TRAIN_DIR_CLIPS = get_dir(os.path.join(DATA_DIR, 'Clips/WoodClips/'))
# For processing clips. l2 diff between frames must be greater than this
MOVEMENT_THRESHOLD = 100
@@ -83,8 +100,8 @@ MOVEMENT_THRESHOLD = 100
NUM_CLIPS = len(glob(TRAIN_DIR_CLIPS + '*'))
# the height and width of the full frames to test on. Set in avg_runner.py or process_data.py main.
-FULL_HEIGHT = 210
-FULL_WIDTH = 160
+FULL_HEIGHT = 320
+FULL_WIDTH = 180
# the height and width of the patches to train on
TRAIN_HEIGHT = TRAIN_WIDTH = 32
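
One subtlety these setters rely on: set_test_dir and set_process_dir rebind module-level globals, so callers must read the constants through the module object rather than via from-imports taken at load time. A minimal illustration (a hypothetical snippet, not part of this commit):

import constants as c
from constants import FULL_HEIGHT  # snapshots the value at import time

c.set_process_dir('../Data/WoodFlat/Process/')  # path hypothetical

print(c.FULL_HEIGHT)  # reflects the first frame found in PROCESS_DIR
print(FULL_HEIGHT)    # still the stale module-load value
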
diff --git a/Code/d_model.py b/Code/d_model.py
index 1345ceb..02b882d 100644
--- a/Code/d_model.py
+++ b/Code/d_model.py
@@ -89,8 +89,8 @@ class DiscriminatorModel:
name='train_op')
# add summaries to visualize in TensorBoard
- loss_summary = tf.scalar_summary('loss_D', self.global_loss)
- self.summaries = tf.merge_summary([loss_summary])
+ loss_summary = tf.summary.scalar('loss_D', self.global_loss)
+ self.summaries = tf.summary.merge([loss_summary])
def build_feed_dict(self, input_frames, gt_output_frames, generator):
"""
@@ -123,14 +123,19 @@ class DiscriminatorModel:
for scale_num in xrange(self.num_scale_nets):
scale_net = self.scale_nets[scale_num]
+ broken = 0
# resize gt_output_frames
scaled_gt_output_frames = np.empty([batch_size, scale_net.height, scale_net.width, 3])
for i, img in enumerate(gt_output_frames):
# for skimage.transform.resize, images need to be in range [0, 1], so normalize to
# [0, 1] before resize and back to [-1, 1] after
sknorm_img = (img / 2) + 0.5
- resized_frame = resize(sknorm_img, [scale_net.height, scale_net.width, 3])
- scaled_gt_output_frames[i] = (resized_frame - 0.5) * 2
+ try:
+ resized_frame = resize(sknorm_img, [scale_net.height, scale_net.width, 3])
+ scaled_gt_output_frames[i - broken] = (resized_frame - 0.5) * 2
+ except Exception:
+ # skip frames that fail to resize rather than crashing the batch
+ broken += 1
# combine with resized gt_output_frames to get inputs for prediction
scaled_input_frames = np.concatenate([g_scale_preds[scale_num],
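
Note that the try/except above skips unreadable frames but, whenever broken > 0, leaves the trailing rows of the preallocated np.empty buffer uninitialized. A sketch of a tighter variant that returns only the frames that actually resized (an assumption about the intent, not code from this commit):

import numpy as np
from skimage.transform import resize

def resize_batch(frames, height, width):
    # Resize a batch of [-1, 1] frames, dropping any that fail.
    out = []
    for img in frames:
        try:
            sknorm_img = (img / 2) + 0.5         # [-1, 1] -> [0, 1] for skimage
            resized = resize(sknorm_img, [height, width, 3])
            out.append((resized - 0.5) * 2)      # back to [-1, 1]
        except Exception:
            continue                             # skip broken frames
    return np.array(out)
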
diff --git a/Code/g_model.py b/Code/g_model.py
index 5dc8265..8ff75a1 100644
--- a/Code/g_model.py
+++ b/Code/g_model.py
@@ -114,7 +114,7 @@ class GeneratorModel:
if scale_num > 0:
last_gen_frames = tf.image.resize_images(
last_gen_frames,[scale_height, scale_width])
- inputs = tf.concat(3, [inputs, last_gen_frames])
+ inputs = tf.concat(axis=3, values=[inputs, last_gen_frames])
# generated frame predictions
preds = inputs
@@ -196,7 +196,7 @@ class GeneratorModel:
name='train_op')
# train loss summary
- loss_summary = tf.scalar_summary('train_loss_G', self.global_loss)
+ loss_summary = tf.summary.scalar('train_loss_G', self.global_loss)
self.summaries_train.append(loss_summary)
##
@@ -215,22 +215,22 @@ class GeneratorModel:
self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1],
self.gt_frames_test)
# train error summaries
- summary_psnr_train = tf.scalar_summary('train_PSNR',
+ summary_psnr_train = tf.summary.scalar('train_PSNR',
self.psnr_error_train)
- summary_sharpdiff_train = tf.scalar_summary('train_SharpDiff',
+ summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff',
self.sharpdiff_error_train)
self.summaries_train += [summary_psnr_train, summary_sharpdiff_train]
# test error
- summary_psnr_test = tf.scalar_summary('test_PSNR',
+ summary_psnr_test = tf.summary.scalar('test_PSNR',
self.psnr_error_test)
- summary_sharpdiff_test = tf.scalar_summary('test_SharpDiff',
+ summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff',
self.sharpdiff_error_test)
self.summaries_test += [summary_psnr_test, summary_sharpdiff_test]
# add summaries to visualize in TensorBoard
- self.summaries_train = tf.merge_summary(self.summaries_train)
- self.summaries_test = tf.merge_summary(self.summaries_test)
+ self.summaries_train = tf.summary.merge(self.summaries_train)
+ self.summaries_test = tf.summary.merge(self.summaries_test)
def train_step(self, batch, discriminator=None):
"""
@@ -307,13 +307,17 @@ class GeneratorModel:
scale_width = int(self.width_train * scale_factor)
# resize gt_output_frames for scale and append to scale_gts_train
+ broken = 0
scaled_gt_frames = np.empty([c.BATCH_SIZE, scale_height, scale_width, 3])
for i, img in enumerate(gt_frames):
# for skimage.transform.resize, images need to be in range [0, 1], so normalize
# to [0, 1] before resize and back to [-1, 1] after
sknorm_img = (img / 2) + 0.5
- resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
- scaled_gt_frames[i] = (resized_frame - 0.5) * 2
+ try:
+ resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
+ scaled_gt_frames[i - broken] = (resized_frame - 0.5) * 2
+ except Exception:
+ # skip frames that fail to resize rather than crashing the batch
+ broken += 1
scale_gts.append(scaled_gt_frames)
# for every clip in the batch, save the inputs, scale preds and scale gts
@@ -342,7 +346,7 @@ class GeneratorModel:
return global_step
- def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True):
+ def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True, process_only=False):
"""
Runs a test step using the global loss on each of the scale networks.
@@ -408,7 +412,14 @@ class GeneratorModel:
# Save images
##
- if save_imgs:
+ if process_only:
+ pred_dir = c.get_dir(os.path.join(
+ c.IMG_SAVE_DIR, 'Process/Step_' + str(global_step)))
+ for rec_num in xrange(num_rec_out):
+ gen_img = rec_preds[rec_num][0]
+ imsave(os.path.join(pred_dir, 'gen_{:05d}.png'.format(rec_num)), gen_img)
+
+ elif save_imgs:
for pred_num in xrange(len(input_frames)):
pred_dir = c.get_dir(os.path.join(
c.IMG_SAVE_DIR, 'Tests/Step_' + str(global_step), str(pred_num)))
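
The process-only branch leans on the recursive-prediction loop that test_batch already implements: each generated frame is fed back as input history for the next step. A schematic of that pattern (not the exact test_batch code; predict_fn stands in for a generator forward pass):

import numpy as np

def recurse_predict(predict_fn, seed_frames, num_rec_out):
    # seed_frames: [H, W, 3 * hist_len] normalized input history.
    # predict_fn: maps an [H, W, 3 * hist_len] history to one [H, W, 3] frame.
    history = seed_frames
    preds = []
    for _ in range(num_rec_out):
        gen = predict_fn(history)
        preds.append(gen)
        # drop the oldest frame, append the newest prediction
        history = np.concatenate([history[:, :, 3:], gen], axis=2)
    return preds
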
diff --git a/Code/loss_functions.py b/Code/loss_functions.py
index 994d226..0d33b49 100644
--- a/Code/loss_functions.py
+++ b/Code/loss_functions.py
@@ -61,7 +61,7 @@ def lp_loss(gen_frames, gt_frames, l_num):
scale_losses.append(tf.reduce_sum(tf.abs(gen_frames[i] - gt_frames[i])**l_num))
# condense into one tensor and avg
- return tf.reduce_mean(tf.pack(scale_losses))
+ return tf.reduce_mean(tf.stack(scale_losses))
def gdl_loss(gen_frames, gt_frames, alpha):
@@ -80,8 +80,8 @@ def gdl_loss(gen_frames, gt_frames, alpha):
# create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively.
pos = tf.constant(np.identity(3), dtype=tf.float32)
neg = -1 * pos
- filter_x = tf.expand_dims(tf.pack([neg, pos]), 0) # [-1, 1]
- filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]]
+ filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) # [-1, 1]
+ filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]]
strides = [1, 1, 1, 1] # stride of (1, 1)
padding = 'SAME'
@@ -96,7 +96,7 @@ def gdl_loss(gen_frames, gt_frames, alpha):
scale_losses.append(tf.reduce_sum((grad_diff_x ** alpha + grad_diff_y ** alpha)))
# condense into one tensor and avg
- return tf.reduce_mean(tf.pack(scale_losses))
+ return tf.reduce_mean(tf.stack(scale_losses))
def adv_loss(preds, labels):
@@ -115,4 +115,4 @@ def adv_loss(preds, labels):
scale_losses.append(loss)
# condense into one tensor and avg
- return tf.reduce_mean(tf.pack(scale_losses))
+ return tf.reduce_mean(tf.stack(scale_losses))
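
The renames in this file and the others track the TensorFlow 1.0 API migration. A compact reference for the substitutions this commit makes (TF 1.x assumed throughout):

import tensorflow as tf  # TF 1.x

# pre-1.0 name                ->  1.0+ equivalent used in this commit
# tf.pack(values)             ->  tf.stack(values)
# tf.concat(dim, values)      ->  tf.concat(values, axis=dim)
# tf.scalar_summary(tag, v)   ->  tf.summary.scalar(tag, v)
# tf.merge_summary(summaries) ->  tf.summary.merge(summaries)
# tf.train.SummaryWriter      ->  tf.summary.FileWriter

scale_losses = [tf.constant(1.0), tf.constant(2.0)]
mean_loss = tf.reduce_mean(tf.stack(scale_losses))  # was tf.pack
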
diff --git a/Code/tfutils.py b/Code/tfutils.py
index 22baf95..0d2d8dd 100644
--- a/Code/tfutils.py
+++ b/Code/tfutils.py
@@ -92,12 +92,12 @@ def batch_pad_to_bounding_box(images, offset_height, offset_width, target_height
rpad = np.zeros([batch_size, target_height, num_rpad, channels])
padded = images
- if num_tpad > 0 and num_bpad > 0: padded = tf.concat(1, [tpad, padded, bpad])
- elif num_tpad > 0: padded = tf.concat(1, [tpad, padded])
- elif num_bpad > 0: padded = tf.concat(1, [padded, bpad])
- if num_lpad > 0 and num_rpad > 0: padded = tf.concat(2, [lpad, padded, rpad])
- elif num_lpad > 0: padded = tf.concat(2, [lpad, padded])
- elif num_rpad > 0: padded = tf.concat(2, [padded, rpad])
+ if num_tpad > 0 and num_bpad > 0: padded = tf.concat(axis=1, values=[tpad, padded, bpad])
+ elif num_tpad > 0: padded = tf.concat(axis=1, values=[tpad, padded])
+ elif num_bpad > 0: padded = tf.concat(axis=1, values=[padded, bpad])
+ if num_lpad > 0 and num_rpad > 0: padded = tf.concat(axis=2, values=[lpad, padded, rpad])
+ elif num_lpad > 0: padded = tf.concat(axis=2, values=[lpad, padded])
+ elif num_rpad > 0: padded = tf.concat(axis=2, values=[padded, rpad])
return padded
diff --git a/Code/utils.py b/Code/utils.py
index 39a7e11..d5c5d05 100644
--- a/Code/utils.py
+++ b/Code/utils.py
@@ -90,6 +90,36 @@ def get_full_clips(data_dir, num_clips, num_rec_out=1):
return clips
+def get_all_clips(data_dir):
+ """
+ Loads every frame in a directory as a single clip.
+
+ @param data_dir: The directory of frames to read, e.g. c.PROCESS_DIR.
+
+ @return: An array of shape
+ [1, c.FULL_HEIGHT, c.FULL_WIDTH, (3 * num_frames)].
+ A single frame sequence with values normalized in range [-1, 1].
+ """
+ # get all the frame paths in sorted order
+ # NOTE: the directory should contain at least c.HIST_LEN frames
+ clip_frame_paths = sorted(glob(os.path.join(data_dir, '*')))
+
+ clips = np.empty([1,
+ c.FULL_HEIGHT,
+ c.FULL_WIDTH,
+ (3 * len(clip_frame_paths))])
+
+ # read in frames
+ for frame_num, frame_path in enumerate(clip_frame_paths):
+ frame = imread(frame_path, mode='RGB')
+ norm_frame = normalize_frames(frame)
+
+ clips[0, :, :, frame_num * 3:(frame_num + 1) * 3] = norm_frame
+
+ return clips
+
def process_clip():
"""
Gets a clip from the train dataset, cropped randomly to c.TRAIN_HEIGHT x c.TRAIN_WIDTH.
@@ -193,8 +223,8 @@ def sharp_diff_error(gen_frames, gt_frames):
# TODO: Could this be simplified with one filter [[-1, 2], [0, -1]]?
pos = tf.constant(np.identity(3), dtype=tf.float32)
neg = -1 * pos
- filter_x = tf.expand_dims(tf.pack([neg, pos]), 0) # [-1, 1]
- filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]]
+ filter_x = tf.expand_dims(tf.stack([neg, pos]), 0) # [-1, 1]
+ filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]]
strides = [1, 1, 1, 1] # stride of (1, 1)
padding = 'SAME'
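
For intuition, filter_x and filter_y are per-channel difference kernels: filter_x diffs each pixel against its horizontal neighbor and filter_y against its vertical neighbor, independently for R, G, and B (the identity matrices keep the channels separate). A minimal shape check under TF 1.x (values illustrative):

import numpy as np
import tensorflow as tf  # TF 1.x

pos = tf.constant(np.identity(3), dtype=tf.float32)
neg = -1 * pos
filter_x = tf.expand_dims(tf.stack([neg, pos]), 0)  # [1, 2, 3, 3]: horizontal diff
filter_y = tf.stack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [2, 1, 3, 3]: vertical diff

img = tf.random_uniform([1, 4, 4, 3])  # toy batch of one 4x4 RGB frame
dx = tf.nn.conv2d(img, filter_x, [1, 1, 1, 1], 'SAME')
dy = tf.nn.conv2d(img, filter_y, [1, 1, 1, 1], 'SAME')
with tf.Session() as sess:
    print(sess.run([tf.shape(dx), tf.shape(dy)]))  # both [1, 4, 4, 3]
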