Diffstat (limited to 'Code')
| -rw-r--r-- | Code/avg_runner.py | 173 |
| -rw-r--r-- | Code/constants.py | 198 |
| -rw-r--r-- | Code/d_model.py | 187 |
| -rw-r--r-- | Code/d_scale_model.py | 153 |
| -rw-r--r-- | Code/g_model.py | 428 |
| -rw-r--r-- | Code/loss_functions.py | 118 |
| -rw-r--r-- | Code/loss_functions_test.py | 304 |
| -rw-r--r-- | Code/process_data.py | 71 |
| -rw-r--r-- | Code/tfutils.py | 133 |
| -rw-r--r-- | Code/tfutils_test.py | 102 |
| -rw-r--r-- | Code/utils.py | 212 |
11 files changed, 2079 insertions, 0 deletions
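Before the per-file diffs, a minimal sketch of how this commit's entry point is driven, mirroring what main() in Code/avg_runner.py below does after flag parsing. The checkpoint path in the comment is illustrative, not an artifact of this commit:

    # Construct and run the trainer much as main() does after parsing flags.
    from avg_runner import AVGRunner

    load_path = None    # or a saved checkpoint, e.g. '../Save/Models/Default/model.ckpt-10000'
    num_test_rec = 2    # recursive test predictions (the -r/--recursions flag)

    runner = AVGRunner(load_path, num_test_rec)
    runner.train()      # loops indefinitely: optional D step, G step, periodic save/test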
diff --git a/Code/avg_runner.py b/Code/avg_runner.py
new file mode 100644
index 0000000..5de994b
--- /dev/null
+++ b/Code/avg_runner.py
@@ -0,0 +1,173 @@
+import tensorflow as tf
+import getopt
+import sys
+
+from utils import get_train_batch, get_test_batch
+import constants as c
+from g_model import GeneratorModel
+from d_model import DiscriminatorModel
+
+
+class AVGRunner:
+    def __init__(self, model_load_path, num_test_rec):
+        """
+        Initializes the Adversarial Video Generation Runner.
+
+        @param model_load_path: The path from which to load a previously-saved model.
+                                Default = None.
+        @param num_test_rec: The number of recursive generations to produce when testing.
+                             Recursive generations use previous generations as input to predict
+                             further into the future.
+        """
+
+        self.global_step = 0
+        self.num_test_rec = num_test_rec
+
+        self.sess = tf.Session()
+        self.summary_writer = tf.train.SummaryWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
+
+        if c.ADVERSARIAL:
+            print 'Init discriminator...'
+            self.d_model = DiscriminatorModel(self.sess,
+                                              self.summary_writer,
+                                              c.TRAIN_HEIGHT,
+                                              c.TRAIN_WIDTH,
+                                              c.SCALE_CONV_FMS_D,
+                                              c.SCALE_KERNEL_SIZES_D,
+                                              c.SCALE_FC_LAYER_SIZES_D)
+
+        print 'Init generator...'
+        self.g_model = GeneratorModel(self.sess,
+                                      self.summary_writer,
+                                      c.TRAIN_HEIGHT,
+                                      c.TRAIN_WIDTH,
+                                      c.TEST_HEIGHT,
+                                      c.TEST_WIDTH,
+                                      c.SCALE_FMS_G,
+                                      c.SCALE_KERNEL_SIZES_G)
+
+        print 'Init variables...'
+        self.saver = tf.train.Saver(keep_checkpoint_every_n_hours=2)
+        self.sess.run(tf.initialize_all_variables())
+
+        # if load path specified, load a saved model
+        if model_load_path is not None:
+            self.saver.restore(self.sess, model_load_path)
+            print 'Model restored from ' + model_load_path
+
+    def train(self):
+        """
+        Runs a training loop on the model networks.
+        """
+        while True:
+            if c.ADVERSARIAL:
+                # update discriminator
+                batch = get_train_batch()
+                print 'Training discriminator...'
+                self.d_model.train_step(batch, self.g_model)
+
+            # update generator
+            batch = get_train_batch()
+            print 'Training generator...'
+            self.global_step = self.g_model.train_step(
+                batch, discriminator=(self.d_model if c.ADVERSARIAL else None))
+
+            # save the models
+            if self.global_step % c.MODEL_SAVE_FREQ == 0:
+                print '-' * 30
+                print 'Saving models...'
+                self.saver.save(self.sess,
+                                c.MODEL_SAVE_DIR + 'model.ckpt',
+                                global_step=self.global_step)
+                print 'Saved models!'
+                print '-' * 30
+
+            # test generator model
+            if self.global_step % c.TEST_FREQ == 0:
+                self.test()
+
+    def test(self):
+        """
+        Runs one test step on the generator network.
+        """
+        batch = get_test_batch(c.BATCH_SIZE, num_rec_out=self.num_test_rec)
+        self.g_model.test_batch(
+            batch, self.global_step, num_rec_out=self.num_test_rec)
+
+
+def usage():
+    print 'Options:'
+    print '-l/--load_path= <Relative/path/to/saved/model>'
+    print '-t/--test_dir= <Directory of test images>'
+    print '-r/--recursions= <# recursive predictions to make on test>'
+    print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)'
+    print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>'
+    print '-O/--overwrite (Overwrites all previous data for the model with this save name)'
+    print '-T/--test_only (Only runs a test step -- no training)'
+    print '-H/--help (prints usage)'
+    print '--stats_freq= <how often to print loss/train error stats, in # steps>'
+    print '--summary_freq= <how often to save loss/error summaries, in # steps>'
+    print '--img_save_freq= <how often to save generated images, in # steps>'
+    print '--test_freq= <how often to test the model on test data, in # steps>'
+    print '--model_save_freq= <how often to save the model, in # steps>'
+
+
+def main():
+    ##
+    # Handle command line input.
+    ##
+
+    load_path = None
+    test_only = False
+    num_test_rec = 1  # number of recursive predictions to make on test
+    try:
+        opts, _ = getopt.getopt(sys.argv[1:], 'l:t:r:a:n:OTH',
+                                ['load_path=', 'test_dir=', 'recursions=', 'adversarial=', 'name=',
+                                 'overwrite', 'test_only', 'help', 'stats_freq=', 'summary_freq=',
+                                 'img_save_freq=', 'test_freq=', 'model_save_freq='])
+    except getopt.GetoptError:
+        usage()
+        sys.exit(2)
+
+    for opt, arg in opts:
+        if opt in ('-l', '--load_path'):
+            load_path = arg
+        if opt in ('-t', '--test_dir'):
+            c.set_test_dir(arg)
+        if opt in ('-r', '--recursions'):
+            num_test_rec = int(arg)
+        if opt in ('-a', '--adversarial'):
+            c.ADVERSARIAL = (arg.lower() == 'true' or arg.lower() == 't')
+        if opt in ('-n', '--name'):
+            c.set_save_name(arg)
+        if opt in ('-O', '--overwrite'):
+            c.clear_save_name()
+        if opt in ('-H', '--help'):
+            usage()
+            sys.exit(2)
+        if opt in ('-T', '--test_only'):
+            test_only = True
+        if opt == '--stats_freq':
+            c.STATS_FREQ = int(arg)
+        if opt == '--summary_freq':
+            c.SUMMARY_FREQ = int(arg)
+        if opt == '--img_save_freq':
+            c.IMG_SAVE_FREQ = int(arg)
+        if opt == '--test_freq':
+            c.TEST_FREQ = int(arg)
+        if opt == '--model_save_freq':
+            c.MODEL_SAVE_FREQ = int(arg)
+
+    ##
+    # Init and run the predictor
+    ##
+
+    runner = AVGRunner(load_path, num_test_rec)
+    if test_only:
+        runner.test()
+    else:
+        runner.train()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/Code/constants.py b/Code/constants.py
new file mode 100644
index 0000000..afe8f9d
--- /dev/null
+++ b/Code/constants.py
@@ -0,0 +1,198 @@
+import numpy as np
+import os
+from glob import glob
+import shutil
+from datetime import datetime
+from scipy.ndimage import imread
+
+##
+# Data
+##
+
+def get_date_str():
+    """
+    @return: A string representing the current date/time that can be used as a directory name.
+    """
+    return str(datetime.now()).replace(' ', '_').replace(':', '.')[:-10]
+
+def get_dir(directory):
+    """
+    Creates the given directory if it does not exist.
+
+    @param directory: The path to the directory.
+    @return: The path to the directory.
+    """
+    if not os.path.exists(directory):
+        os.makedirs(directory)
+    return directory
+
+def clear_dir(directory):
+    """
+    Removes all files in the given directory.
+
+    @param directory: The path to the directory.
+    """
+    for f in os.listdir(directory):
+        path = os.path.join(directory, f)
+        try:
+            if os.path.isfile(path):
+                os.unlink(path)
+            elif os.path.isdir(path):
+                shutil.rmtree(path)
+        except Exception as e:
+            print(e)
+
+def get_test_frame_dims():
+    img_path = glob(TEST_DIR + '*/*')[0]
+    img = imread(img_path, mode='RGB')
+    shape = np.shape(img)
+
+    return shape[0], shape[1]
+
+def set_test_dir(directory):
+    """
+    Edits all constants dependent on TEST_DIR.
+
+    @param directory: The new test directory.
+ """ + global TEST_DIR, TEST_HEIGHT, TEST_WIDTH + + TEST_DIR = directory + TEST_HEIGHT, TEST_WIDTH = get_test_frame_dims() + +# root directory for all data +DATA_DIR = get_dir('../Data/') +# directory of unprocessed training frames +TRAIN_DIR = DATA_DIR + 'Ms_Pacman/Train/' +# directory of unprocessed test frames +TEST_DIR = DATA_DIR + 'Ms_Pacman/Test/' +# Directory of processed training clips. +# hidden so finder doesn't freeze w/ so many files. DON'T USE `ls` COMMAND ON THIS DIR! +TRAIN_DIR_CLIPS = get_dir(DATA_DIR + '.Clips/') + +# For processing clips. l2 diff between frames must be greater than this +MOVEMENT_THRESHOLD = 100 +# total number of processed clips in TRAIN_DIR_CLIPS +NUM_CLIPS = len(glob(TRAIN_DIR_CLIPS + '*')) + +# the height and width of the full frames to test on +TEST_HEIGHT, TEST_WIDTH = get_test_frame_dims() +# the height and width of the patches to train on +TRAIN_HEIGHT = TRAIN_WIDTH = 32 + +## +# Output +## + +def set_save_name(name): + """ + Edits all constants dependent on SAVE_NAME. + + @param name: The new save name. + """ + global SAVE_NAME, MODEL_SAVE_DIR, SUMMARY_SAVE_DIR, IMG_SAVE_DIR + + SAVE_NAME = name + MODEL_SAVE_DIR = get_dir(SAVE_DIR + 'Models/' + SAVE_NAME) + SUMMARY_SAVE_DIR = get_dir(SAVE_DIR + 'Summaries/' + SAVE_NAME) + IMG_SAVE_DIR = get_dir(SAVE_DIR + 'Images/' + SAVE_NAME) + +def clear_save_name(): + """ + Clears all saved content for SAVE_NAME. + """ + clear_dir(MODEL_SAVE_DIR) + clear_dir(SUMMARY_SAVE_DIR) + clear_dir(IMG_SAVE_DIR) + + +# root directory for all saved content +SAVE_DIR = get_dir('../Save/') + +# inner directory to differentiate between runs +SAVE_NAME = 'Default/' +# directory for saved models +MODEL_SAVE_DIR = get_dir(SAVE_DIR + 'Models/' + SAVE_NAME) +# directory for saved TensorBoard summaries +SUMMARY_SAVE_DIR = get_dir(SAVE_DIR + 'Summaries/' + SAVE_NAME) +# directory for saved images +IMG_SAVE_DIR = get_dir(SAVE_DIR + 'Images/' + SAVE_NAME) + + +STATS_FREQ = 10 # how often to print loss/train error stats, in # steps +SUMMARY_FREQ = 100 # how often to save the summaries, in # steps +IMG_SAVE_FREQ = 1000 # how often to save generated images, in # steps +TEST_FREQ = 5000 # how often to test the model on test data, in # steps +MODEL_SAVE_FREQ = 10000 # how often to save the model, in # steps + +## +# General training +## + +# whether to use adversarial training vs. basic training of the generator +ADVERSARIAL = True +# the training minibatch size +BATCH_SIZE = 8 +# the number of history frames to give as input to the network +HIST_LEN = 4 + +## +# Loss parameters +## + +# for lp loss. e.g, 1 or 2 for l1 and l2 loss, respectively) +L_NUM = 2 +# the power to which each gradient term is raised in GDL loss +ALPHA_NUM = 1 +# the percentage of the adversarial loss to use in the combined loss +LAM_ADV = 0.05 +# the percentage of the lp loss to use in the combined loss +LAM_LP = 1 +# the percentage of the GDL loss to use in the combined loss +LAM_GDL = 1 + +## +# Generator model +## + +# learning rate for the generator model +LRATE_G = 0.00004 # Value in paper is 0.04 +# padding for convolutions in the generator model +PADDING_G = 'SAME' +# feature maps for each convolution of each scale network in the generator model +# e.g SCALE_FMS_G[1][2] is the input of the 3rd convolution in the 2nd scale network. 
+SCALE_FMS_G = [[3 * HIST_LEN, 128, 256, 128, 3], + [3 * (HIST_LEN + 1), 128, 256, 128, 3], + [3 * (HIST_LEN + 1), 128, 256, 512, 256, 128, 3], + [3 * (HIST_LEN + 1), 128, 256, 512, 256, 128, 3]] +# kernel sizes for each convolution of each scale network in the generator model +SCALE_KERNEL_SIZES_G = [[3, 3, 3, 3], + [5, 3, 3, 5], + [5, 3, 3, 3, 3, 5], + [7, 5, 5, 5, 5, 7]] + + +## +# Discriminator model +## + +# learning rate for the discriminator model +LRATE_D = 0.02 +# padding for convolutions in the discriminator model +PADDING_D = 'VALID' +# feature maps for each convolution of each scale network in the discriminator model +SCALE_CONV_FMS_D = [[3, 64], + [3, 64, 128, 128], + [3, 128, 256, 256], + [3, 128, 256, 512, 128]] +# kernel sizes for each convolution of each scale network in the discriminator model +SCALE_KERNEL_SIZES_D = [[3], + [3, 3, 3], + [5, 5, 5], + [7, 7, 5, 5]] +# layer sizes for each fully-connected layer of each scale network in the discriminator model +# layer connecting conv to fully-connected is dynamically generated when creating the model +SCALE_FC_LAYER_SIZES_D = [[512, 256, 1], + [1024, 512, 1], + [1024, 512, 1], + [1024, 512, 1]] diff --git a/Code/d_model.py b/Code/d_model.py new file mode 100644 index 0000000..7b1cb12 --- /dev/null +++ b/Code/d_model.py @@ -0,0 +1,187 @@ +import tensorflow as tf +import numpy as np +from skimage.transform import resize + +from d_scale_model import DScaleModel +from loss_functions import adv_loss +import constants as c + + +# noinspection PyShadowingNames +class DiscriminatorModel: + def __init__(self, session, summary_writer, height, width, scale_conv_layer_fms, + scale_kernel_sizes, scale_fc_layer_sizes): + """ + Initializes a GeneratorModel. + + @param session: The TensorFlow session. + @param summary_writer: The writer object to record TensorBoard summaries + @param height: The height of the input images. + @param width: The width of the input images. + @param scale_conv_layer_fms: The number of feature maps in each convolutional layer of each + scale network. + @param scale_kernel_sizes: The size of the kernel for each layer of each scale network. + @param scale_fc_layer_sizes: The number of nodes in each fully-connected layer of each scale + network. + + @type session: tf.Session + @type summary_writer: tf.train.SummaryWriter + @type height: int + @type width: int + @type scale_conv_layer_fms: list<list<int>> + @type scale_kernel_sizes: list<list<int>> + @type scale_fc_layer_sizes: list<list<int>> + """ + self.sess = session + self.summary_writer = summary_writer + self.height = height + self.width = width + self.scale_conv_layer_fms = scale_conv_layer_fms + self.scale_kernel_sizes = scale_kernel_sizes + self.scale_fc_layer_sizes = scale_fc_layer_sizes + self.num_scale_nets = len(scale_conv_layer_fms) + + self.define_graph() + + # noinspection PyAttributeOutsideInit + def define_graph(self): + """ + Sets up the model graph in TensorFlow. + """ + with tf.name_scope('discriminator'): + ## + # Setup scale networks. Each will make the predictions for images at a given scale. + ## + + self.scale_nets = [] + for scale_num in xrange(self.num_scale_nets): + with tf.name_scope('scale_net_' + str(scale_num)): + scale_factor = 1. 
/ 2 ** ((self.num_scale_nets - 1) - scale_num) + self.scale_nets.append(DScaleModel(scale_num, + int(self.height * scale_factor), + int(self.width * scale_factor), + self.scale_conv_layer_fms[scale_num], + self.scale_kernel_sizes[scale_num], + self.scale_fc_layer_sizes[scale_num])) + + # A list of the prediction tensors for each scale network + self.scale_preds = [] + for scale_num in xrange(self.num_scale_nets): + self.scale_preds.append(self.scale_nets[scale_num].preds) + + ## + # Data + ## + + self.labels = tf.placeholder(tf.float32, shape=[None, 1], name='labels') + + ## + # Training + ## + + with tf.name_scope('training'): + # global loss is the combined loss from every scale network + self.global_loss = adv_loss(self.scale_preds, self.labels) + self.global_step = tf.Variable(0, trainable=False, name='global_step') + self.optimizer = tf.train.GradientDescentOptimizer(c.LRATE_D, name='optimizer') + self.train_op = self.optimizer.minimize(self.global_loss, + global_step=self.global_step, + name='train_op') + + # add summaries to visualize in TensorBoard + loss_summary = tf.scalar_summary('loss_D', self.global_loss) + self.summaries = tf.merge_summary([loss_summary]) + + def build_feed_dict(self, input_frames, gt_output_frames, generator): + """ + Builds a feed_dict with resized inputs and outputs for each scale network. + + @param input_frames: An array of shape + [batch_size x self.height x self.width x (3 * HIST_LEN)], The frames to + use for generation. + @param gt_output_frames: An array of shape [batch_size x self.height x self.width x 3], The + ground truth outputs for each sequence in input_frames. + @param generator: The generator model. + + @return: The feed_dict needed to run this network, all scale_nets, and the generator + predictions. + """ + feed_dict = {} + batch_size = np.shape(gt_output_frames)[0] + + ## + # Get generated frames from GeneratorModel + ## + + g_feed_dict = {generator.input_frames_train: input_frames, + generator.gt_frames_train: gt_output_frames} + g_scale_preds = self.sess.run(generator.scale_preds_train, feed_dict=g_feed_dict) + + ## + # Create discriminator feed dict + ## + for scale_num in xrange(self.num_scale_nets): + scale_net = self.scale_nets[scale_num] + + # resize gt_output_frames + scaled_gt_output_frames = np.empty([batch_size, scale_net.height, scale_net.width, 3]) + for i, img in enumerate(gt_output_frames): + # for skimage.transform.resize, images need to be in range [0, 1], so normalize to + # [0, 1] before resize and back to [-1, 1] after + sknorm_img = (img / 2) + 0.5 + resized_frame = resize(sknorm_img, [scale_net.height, scale_net.width, 3]) + scaled_gt_output_frames[i] = (resized_frame - 0.5) * 2 + + # combine with resized gt_output_frames to get inputs for prediction + scaled_input_frames = np.concatenate([g_scale_preds[scale_num], + scaled_gt_output_frames]) + + # convert to np array and add to feed_dict + feed_dict[scale_net.input_frames] = scaled_input_frames + + # add labels for each image to feed_dict + batch_size = np.shape(input_frames)[0] + feed_dict[self.labels] = np.concatenate([np.zeros([batch_size, 1]), + np.ones([batch_size, 1])]) + + return feed_dict + + def train_step(self, batch, generator): + """ + Runs a training step using the global loss on each of the scale networks. + + @param batch: An array of shape + [BATCH_SIZE x self.height x self.width x (3 * (HIST_LEN + 1))]. The input + and output frames, concatenated along the channel axis (index 3). + @param generator: The generator model. 
+ + @return: The global step. + """ + ## + # Split into inputs and outputs + ## + + input_frames = batch[:, :, :, :-3] + gt_output_frames = batch[:, :, :, -3:] + + ## + # Train + ## + + feed_dict = self.build_feed_dict(input_frames, gt_output_frames, generator) + + _, global_loss, global_step, summaries = self.sess.run( + [self.train_op, self.global_loss, self.global_step, self.summaries], + feed_dict=feed_dict) + + ## + # User output + ## + + if global_step % c.STATS_FREQ == 0: + print 'DiscriminatorModel: step %d | global loss: %f' % (global_step, global_loss) + if global_step % c.SUMMARY_FREQ == 0: + print 'DiscriminatorModel: saved summaries' + self.summary_writer.add_summary(summaries, global_step) + + return global_step diff --git a/Code/d_scale_model.py b/Code/d_scale_model.py new file mode 100644 index 0000000..766e01a --- /dev/null +++ b/Code/d_scale_model.py @@ -0,0 +1,153 @@ +import tensorflow as tf +from tfutils import w, b, conv_out_size +import constants as c + + +# noinspection PyShadowingNames +class DScaleModel: + """ + A DScaleModel is a network that takes as input one video frame and attempts to discriminate + whether or not the output frame is a real-world image or one generated by a generator network. + Multiple of these are used together in a DiscriminatorModel to make predictions on frames at + increasing scales. + """ + + def __init__(self, scale_index, height, width, conv_layer_fms, kernel_sizes, fc_layer_sizes): + """ + Initializes the DScaleModel. + + @param scale_index: The index number of this height in the GeneratorModel. + @param height: The height of the input images. + @param width: The width of the input images. + @param conv_layer_fms: The number of output feature maps for each convolution. + @param kernel_sizes: The size of the kernel for each convolutional layer. + @param fc_layer_sizes: The number of nodes in each fully-connected layer. + + @type scale_index: int + @type height: int + @type width: int + @type conv_layer_fms: list<int> + @type kernel_sizes: list<int> (len = len(scale_layer_fms) - 1) + @type fc_layer_sizes: list<int> + """ + assert len(kernel_sizes) == len(conv_layer_fms) - 1, \ + 'len(kernel_sizes) must = len(conv_layer_fms) - 1' + + self.scale_index = scale_index + self.height = height + self.width = width + self.conv_layer_fms = conv_layer_fms + self.kernel_sizes = kernel_sizes + self.fc_layer_sizes = fc_layer_sizes + + self.define_graph() + + # noinspection PyAttributeOutsideInit + def define_graph(self): + """ + Sets up the model graph in TensorFlow. 
+ """ + + ## + # Input data + ## + with tf.name_scope('input'): + self.input_frames = tf.placeholder( + tf.float32, shape=[None, self.height, self.width, self.conv_layer_fms[0]]) + + # use variable batch_size for more flexibility + self.batch_size = tf.shape(self.input_frames)[0] + + ## + # Layer setup + ## + + with tf.name_scope('setup'): + # convolution + with tf.name_scope('convolutions'): + conv_ws = [] + conv_bs = [] + last_out_height = self.height + last_out_width = self.width + for i in xrange(len(self.kernel_sizes)): + conv_ws.append(w([self.kernel_sizes[i], + self.kernel_sizes[i], + self.conv_layer_fms[i], + self.conv_layer_fms[i + 1]])) + conv_bs.append(b([self.conv_layer_fms[i + 1]])) + + last_out_height = conv_out_size( + last_out_height, c.PADDING_D, self.kernel_sizes[i], 1) + last_out_width = conv_out_size( + last_out_width, c.PADDING_D, self.kernel_sizes[i], 1) + + # fully-connected + with tf.name_scope('full-connected'): + # Add in an initial layer to go from the last conv to the first fully-connected. + # Use /2 for the height and width because there is a 2x2 pooling layer + self.fc_layer_sizes.insert( + 0, (last_out_height / 2) * (last_out_width / 2) * self.conv_layer_fms[-1]) + + fc_ws = [] + fc_bs = [] + for i in xrange(len(self.fc_layer_sizes) - 1): + fc_ws.append(w([self.fc_layer_sizes[i], + self.fc_layer_sizes[i + 1]])) + fc_bs.append(b([self.fc_layer_sizes[i + 1]])) + + ## + # Forward pass calculation + ## + + def generate_predictions(): + """ + Runs self.input_frames through the network to generate a prediction from 0 + (generated img) to 1 (real img). + + @return: A tensor of predictions of shape [self.batch_size x 1]. + """ + with tf.name_scope('calculation'): + preds = tf.zeros([self.batch_size, 1]) + last_input = self.input_frames + + # convolutions + with tf.name_scope('convolutions'): + for i in xrange(len(conv_ws)): + # Convolve layer and activate with ReLU + preds = tf.nn.conv2d( + last_input, conv_ws[i], [1, 1, 1, 1], padding=c.PADDING_D) + preds = tf.nn.relu(preds + conv_bs[i]) + + last_input = preds + + # pooling layer + with tf.name_scope('pooling'): + preds = tf.nn.max_pool(preds, [1, 2, 2, 1], [1, 2, 2, 1], padding=c.PADDING_D) + + # flatten preds for dense layers + shape = preds.get_shape().as_list() + # -1 can be used as one dimension to size dynamically + preds = tf.reshape(preds, [-1, shape[1] * shape[2] * shape[3]]) + + # fully-connected layers + with tf.name_scope('fully-connected'): + for i in xrange(len(fc_ws)): + preds = tf.matmul(preds, fc_ws[i]) + fc_bs[i] + + # Activate with ReLU (or Sigmoid for last layer) + if i == len(fc_ws) - 1: + preds = tf.sigmoid(preds) + else: + preds = tf.nn.relu(preds) + + # clip preds between [.1, 0.9] for stability + with tf.name_scope('clip'): + preds = tf.clip_by_value(preds, 0.1, 0.9) + + return preds + + self.preds = generate_predictions() + + ## + # Training handled by DiscriminatorModel + ## diff --git a/Code/g_model.py b/Code/g_model.py new file mode 100644 index 0000000..eef24ab --- /dev/null +++ b/Code/g_model.py @@ -0,0 +1,428 @@ +import tensorflow as tf +import numpy as np +from scipy.misc import imsave +from skimage.transform import resize +from copy import deepcopy + +import constants as c +from loss_functions import combined_loss +from utils import psnr_error, sharp_diff_error +from tfutils import w, b + +# noinspection PyShadowingNames +class GeneratorModel: + def __init__(self, session, summary_writer, height_train, width_train, height_test, + width_test, scale_layer_fms, scale_kernel_sizes): + 
""" + Initializes a GeneratorModel. + + @param session: The TensorFlow Session. + @param summary_writer: The writer object to record TensorBoard summaries + @param height_train: The height of the input images for training. + @param width_train: The width of the input images for training. + @param height_train: The height of the input images for testing. + @param width_train: The width of the input images for testing. + @param scale_layer_fms: The number of feature maps in each layer of each scale network. + @param scale_kernel_sizes: The size of the kernel for each layer of each scale network. + + @type session: tf.Session + @type summary_writer: tf.train.SummaryWriter + @type height_train: int + @type width_train: int + @type height_test: int + @type width_test: int + @type scale_layer_fms: list<list<int>> + @type scale_kernel_sizes: list<list<int>> + """ + self.sess = session + self.summary_writer = summary_writer + self.height_train = height_train + self.width_train = width_train + self.height_test = height_test + self.width_test = width_test + self.scale_layer_fms = scale_layer_fms + self.scale_kernel_sizes = scale_kernel_sizes + self.num_scale_nets = len(scale_layer_fms) + + self.define_graph() + + # noinspection PyAttributeOutsideInit + def define_graph(self): + """ + Sets up the model graph in TensorFlow. + """ + with tf.name_scope('generator'): + ## + # Data + ## + + with tf.name_scope('data'): + self.input_frames_train = tf.placeholder( + tf.float32, shape=[None, self.height_train, self.width_train, 3 * c.HIST_LEN]) + self.gt_frames_train = tf.placeholder( + tf.float32, shape=[None, self.height_train, self.width_train, 3]) + + self.input_frames_test = tf.placeholder( + tf.float32, shape=[None, self.height_test, self.width_test, 3 * c.HIST_LEN]) + self.gt_frames_test = tf.placeholder( + tf.float32, shape=[None, self.height_test, self.width_test, 3]) + + # use variable batch_size for more flexibility + self.batch_size_train = tf.shape(self.input_frames_train)[0] + self.batch_size_test = tf.shape(self.input_frames_test)[0] + + ## + # Scale network setup and calculation + ## + + self.summaries_train = [] + self.scale_preds_train = [] # the generated images at each scale + self.scale_gts_train = [] # the ground truth images at each scale + self.d_scale_preds = [] # the predictions from the discriminator model + + self.summaries_test = [] + self.scale_preds_test = [] # the generated images at each scale + self.scale_gts_test = [] # the ground truth images at each scale + + for scale_num in xrange(self.num_scale_nets): + with tf.name_scope('scale_' + str(scale_num)): + with tf.name_scope('setup'): + ws = [] + bs = [] + + # create weights for kernels + for i in xrange(len(self.scale_kernel_sizes[scale_num])): + ws.append(w([self.scale_kernel_sizes[scale_num][i], + self.scale_kernel_sizes[scale_num][i], + self.scale_layer_fms[scale_num][i], + self.scale_layer_fms[scale_num][i + 1]])) + bs.append(b([self.scale_layer_fms[scale_num][i + 1]])) + + with tf.name_scope('calculation'): + def calculate(height, width, inputs, gts, last_gen_frames): + # scale inputs and gts + scale_factor = 1. 
/ 2 ** ((self.num_scale_nets - 1) - scale_num) + scale_height = int(height * scale_factor) + scale_width = int(width * scale_factor) + + inputs = tf.image.resize_images(inputs, scale_height, scale_width) + scale_gts = tf.image.resize_images(gts, scale_height, scale_width) + + # for all scales but the first, add the frame generated by the last + # scale to the input + if scale_num > 0: + last_gen_frames = tf.image.resize_images(last_gen_frames, + scale_height, + scale_width) + inputs = tf.concat(3, [inputs, last_gen_frames]) + + # generated frame predictions + preds = inputs + + # perform convolutions + with tf.name_scope('convolutions'): + for i in xrange(len(self.scale_kernel_sizes[scale_num])): + # Convolve layer + preds = tf.nn.conv2d( + preds, ws[i], [1, 1, 1, 1], padding=c.PADDING_G) + + # Activate with ReLU (or Tanh for last layer) + if i == len(self.scale_kernel_sizes[scale_num]) - 1: + preds = tf.nn.tanh(preds + bs[i]) + else: + preds = tf.nn.relu(preds + bs[i]) + + return preds, scale_gts + + ## + # Perform train calculation + ## + + # for all scales but the first, add the frame generated by the last + # scale to the input + if scale_num > 0: + last_scale_pred_train = self.scale_preds_train[scale_num - 1] + else: + last_scale_pred_train = None + + # calculate + train_preds, train_gts = calculate(self.height_train, + self.width_train, + self.input_frames_train, + self.gt_frames_train, + last_scale_pred_train) + self.scale_preds_train.append(train_preds) + self.scale_gts_train.append(train_gts) + + # We need to run the network first to get generated frames, run the + # discriminator on those frames to get d_scale_preds, then run this + # again for the loss optimization. + if c.ADVERSARIAL: + self.d_scale_preds.append(tf.placeholder(tf.float32, [None, 1])) + + ## + # Perform test calculation + ## + + # for all scales but the first, add the frame generated by the last + # scale to the input + if scale_num > 0: + last_scale_pred_test = self.scale_preds_test[scale_num - 1] + else: + last_scale_pred_test = None + + # calculate + test_preds, test_gts = calculate(self.height_test, + self.width_test, + self.input_frames_test, + self.gt_frames_test, + last_scale_pred_test) + self.scale_preds_test.append(test_preds) + self.scale_gts_test.append(test_gts) + + ## + # Training + ## + + with tf.name_scope('train'): + # global loss is the combined loss from every scale network + self.global_loss = combined_loss(self.scale_preds_train, + self.scale_gts_train, + self.d_scale_preds) + self.global_step = tf.Variable(0, trainable=False) + self.optimizer = tf.train.AdamOptimizer(learning_rate=c.LRATE_G, name='optimizer') + self.train_op = self.optimizer.minimize(self.global_loss, + global_step=self.global_step, + name='train_op') + + # train loss summary + loss_summary = tf.scalar_summary('train_loss_G', self.global_loss) + self.summaries_train.append(loss_summary) + + ## + # Error + ## + + with tf.name_scope('error'): + # error computation + # get error at largest scale + self.psnr_error_train = psnr_error(self.scale_preds_train[-1], + self.gt_frames_train) + self.sharpdiff_error_train = sharp_diff_error(self.scale_preds_train[-1], + self.gt_frames_train) + self.psnr_error_test = psnr_error(self.scale_preds_test[-1], + self.gt_frames_test) + self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1], + self.gt_frames_test) + # train error summaries + summary_psnr_train = tf.scalar_summary('train_PSNR', + self.psnr_error_train) + summary_sharpdiff_train = tf.scalar_summary('train_SharpDiff', 
+ self.sharpdiff_error_train) + self.summaries_train += [summary_psnr_train, summary_sharpdiff_train] + + # test error + summary_psnr_test = tf.scalar_summary('test_PSNR', + self.psnr_error_test) + summary_sharpdiff_test = tf.scalar_summary('test_SharpDiff', + self.sharpdiff_error_test) + self.summaries_test += [summary_psnr_test, summary_sharpdiff_test] + + # add summaries to visualize in TensorBoard + self.summaries_train = tf.merge_summary(self.summaries_train) + self.summaries_test = tf.merge_summary(self.summaries_test) + + def train_step(self, batch, discriminator=None): + """ + Runs a training step using the global loss on each of the scale networks. + + @param batch: An array of shape + [c.BATCH_SIZE x self.height x self.width x (3 * (c.HIST_LEN + 1))]. + The input and output frames, concatenated along the channel axis (index 3). + @param discriminator: The discriminator model. Default = None, if not adversarial. + + @return: The global step. + """ + ## + # Split into inputs and outputs + ## + + input_frames = batch[:, :, :, :-3] + gt_frames = batch[:, :, :, -3:] + + ## + # Train + ## + + feed_dict = {self.input_frames_train: input_frames, self.gt_frames_train: gt_frames} + + if c.ADVERSARIAL: + # Run the generator first to get generated frames + scale_preds = self.sess.run(self.scale_preds_train, feed_dict=feed_dict) + + # Run the discriminator nets on those frames to get predictions + d_feed_dict = {} + for scale_num, gen_frames in enumerate(scale_preds): + d_feed_dict[discriminator.scale_nets[scale_num].input_frames] = gen_frames + d_scale_preds = self.sess.run(discriminator.scale_preds, feed_dict=d_feed_dict) + + # Add discriminator predictions to the + for i, preds in enumerate(d_scale_preds): + feed_dict[self.d_scale_preds[i]] = preds + + _, global_loss, global_psnr_error, global_sharpdiff_error, global_step, summaries = \ + self.sess.run([self.train_op, + self.global_loss, + self.psnr_error_train, + self.sharpdiff_error_train, + self.global_step, + self.summaries_train], + feed_dict=feed_dict) + + ## + # User output + ## + if global_step % c.STATS_FREQ == 0: + print 'GeneratorModel : Step ', global_step + print ' Global Loss : ', global_loss + print ' PSNR Error : ', global_psnr_error + print ' Sharpdiff Error: ', global_sharpdiff_error + if global_step % c.SUMMARY_FREQ == 0: + self.summary_writer.add_summary(summaries, global_step) + print 'GeneratorModel: saved summaries' + if global_step % c.IMG_SAVE_FREQ == 0: + print '-' * 30 + print 'Saving images...' + + # if not adversarial, we didn't get the preds for each scale net before for the + # discriminator prediction, so do it now + if not c.ADVERSARIAL: + scale_preds = self.sess.run(self.scale_preds_train, feed_dict=feed_dict) + + # re-generate scale gt_frames to avoid having to run through TensorFlow. + scale_gts = [] + for scale_num in xrange(self.num_scale_nets): + scale_factor = 1. 
/ 2 ** ((self.num_scale_nets - 1) - scale_num) + scale_height = int(self.height_train * scale_factor) + scale_width = int(self.width_train * scale_factor) + + # resize gt_output_frames for scale and append to scale_gts_train + scaled_gt_frames = np.empty([c.BATCH_SIZE, scale_height, scale_width, 3]) + for i, img in enumerate(gt_frames): + # for skimage.transform.resize, images need to be in range [0, 1], so normalize + # to [0, 1] before resize and back to [-1, 1] after + sknorm_img = (img / 2) + 0.5 + resized_frame = resize(sknorm_img, [scale_height, scale_width, 3]) + scaled_gt_frames[i] = (resized_frame - 0.5) * 2 + scale_gts.append(scaled_gt_frames) + + # for every clip in the batch, save the inputs, scale preds and scale gts + for pred_num in xrange(len(input_frames)): + pred_dir = c.get_dir(c.IMG_SAVE_DIR + 'Step_' + str(global_step) + '/' + str( + pred_num) + '/') + + # save input images + for frame_num in xrange(c.HIST_LEN): + img = input_frames[pred_num, :, :, (frame_num * 3):((frame_num + 1) * 3)] + imsave(pred_dir + 'input_' + str(frame_num) + '.png', img) + + # save preds and gts at each scale + # noinspection PyUnboundLocalVariable + for scale_num, scale_pred in enumerate(scale_preds): + gen_img = scale_pred[pred_num] + + path = pred_dir + 'scale' + str(scale_num) + gt_img = scale_gts[scale_num][pred_num] + + imsave(path + '_gen.png', gen_img) + imsave(path + '_gt.png', gt_img) + + print 'Saved images!' + print '-' * 30 + + return global_step + + def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True): + """ + Runs a training step using the global loss on each of the scale networks. + + @param batch: An array of shape + [batch_size x self.height x self.width x (3 * (c.HIST_LEN+ num_rec_out))]. + A batch of the input and output frames, concatenated along the channel axis + (index 3). + @param global_step: The global step. + @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively, + using previously-generated frames as input. Default = 1. + @param save_imgs: Whether or not to save the input/output images to file. Default = True. + + @return: A tuple of (psnr error, sharpdiff error) for the batch. + """ + if num_rec_out < 1: + raise ValueError('num_rec_out must be >= 1') + + print '-' * 30 + print 'Testing:' + + ## + # Split into inputs and outputs + ## + + input_frames = batch[:, :, :, :3 * c.HIST_LEN] + gt_frames = batch[:, :, :, 3 * c.HIST_LEN:] + + ## + # Generate num_rec_out recursive predictions + ## + + working_input_frames = deepcopy(input_frames) # input frames that will shift w/ recursion + rec_preds = [] + rec_summaries = [] + for rec_num in xrange(num_rec_out): + working_gt_frames = gt_frames[:, :, :, 3 * rec_num:3 * (rec_num + 1)] + + feed_dict = {self.input_frames_test: working_input_frames, + self.gt_frames_test: working_gt_frames} + preds, psnr, sharpdiff, summaries = self.sess.run([self.scale_preds_test[-1], + self.psnr_error_test, + self.sharpdiff_error_test, + self.summaries_test], + feed_dict=feed_dict) + + # remove first input and add new pred as last input + working_input_frames = np.concatenate( + [working_input_frames[:, :, :, 3:], preds], axis=3) + + # add predictions and summaries + rec_preds.append(preds) + rec_summaries.append(summaries) + + print 'Recursion ', rec_num + print 'PSNR Error : ', psnr + print 'Sharpdiff Error: ', sharpdiff + + # write summaries + # TODO: Think of a good way to write rec output summaries - rn, just using first output. 
+ self.summary_writer.add_summary(rec_summaries[0], global_step) + + ## + # Save images + ## + + if save_imgs: + for pred_num in xrange(len(input_frames)): + pred_dir = c.get_dir( + c.IMG_SAVE_DIR + 'Tests/Step_' + str(global_step) + '/' + str(pred_num) + '/') + + # save input images + for frame_num in xrange(c.HIST_LEN): + img = input_frames[pred_num, :, :, (frame_num * 3):((frame_num + 1) * 3)] + imsave(pred_dir + 'input_' + str(frame_num) + '.png', img) + + # save recursive outputs + for rec_num in xrange(num_rec_out): + gen_img = rec_preds[rec_num][pred_num] + gt_img = gt_frames[pred_num, :, :, 3 * rec_num:3 * (rec_num + 1)] + imsave(pred_dir + 'gen_' + str(rec_num) + '.png', gen_img) + imsave(pred_dir + 'gt_' + str(rec_num) + '.png', gt_img) + + print '-' * 30 diff --git a/Code/loss_functions.py b/Code/loss_functions.py new file mode 100644 index 0000000..994d226 --- /dev/null +++ b/Code/loss_functions.py @@ -0,0 +1,118 @@ +import tensorflow as tf +import numpy as np + +from tfutils import log10 +import constants as c + +def combined_loss(gen_frames, gt_frames, d_preds, lam_adv=1, lam_lp=1, lam_gdl=1, l_num=2, alpha=2): + """ + Calculates the sum of the combined adversarial, lp and GDL losses in the given proportion. Used + for training the generative model. + + @param gen_frames: A list of tensors of the generated frames at each scale. + @param gt_frames: A list of tensors of the ground truth frames at each scale. + @param d_preds: A list of tensors of the classifications made by the discriminator model at each + scale. + @param lam_adv: The percentage of the adversarial loss to use in the combined loss. + @param lam_lp: The percentage of the lp loss to use in the combined loss. + @param lam_gdl: The percentage of the GDL loss to use in the combined loss. + @param l_num: 1 or 2 for l1 and l2 loss, respectively). + @param alpha: The power to which each gradient term is raised in GDL loss. + + @return: The combined adversarial, lp and GDL losses. + """ + batch_size = tf.shape(gen_frames[0])[0] # variable batch size as a tensor + + loss = lam_lp * lp_loss(gen_frames, gt_frames, l_num) + loss += lam_gdl * gdl_loss(gen_frames, gt_frames, alpha) + if c.ADVERSARIAL: loss += lam_adv * adv_loss(d_preds, tf.ones([batch_size, 1])) + + return loss + + +def bce_loss(preds, targets): + """ + Calculates the sum of binary cross-entropy losses between predictions and ground truths. + + @param preds: A 1xN tensor. The predicted classifications of each frame. + @param targets: A 1xN tensor The target labels for each frame. (Either 1 or -1). Not "truths" + because the generator passes in lies to determine how well it confuses the + discriminator. + + @return: The sum of binary cross-entropy losses. + """ + return tf.squeeze(-1 * (tf.matmul(targets, log10(preds), transpose_a=True) + + tf.matmul(1 - targets, log10(1 - preds), transpose_a=True))) + + +def lp_loss(gen_frames, gt_frames, l_num): + """ + Calculates the sum of lp losses between the predicted and ground truth frames. + + @param gen_frames: The predicted frames at each scale. + @param gt_frames: The ground truth frames at each scale + @param l_num: 1 or 2 for l1 and l2 loss, respectively). + + @return: The lp loss. 
+ """ + # calculate the loss for each scale + scale_losses = [] + for i in xrange(len(gen_frames)): + scale_losses.append(tf.reduce_sum(tf.abs(gen_frames[i] - gt_frames[i])**l_num)) + + # condense into one tensor and avg + return tf.reduce_mean(tf.pack(scale_losses)) + + +def gdl_loss(gen_frames, gt_frames, alpha): + """ + Calculates the sum of GDL losses between the predicted and ground truth frames. + + @param gen_frames: The predicted frames at each scale. + @param gt_frames: The ground truth frames at each scale + @param alpha: The power to which each gradient term is raised. + + @return: The GDL loss. + """ + # calculate the loss for each scale + scale_losses = [] + for i in xrange(len(gen_frames)): + # create filters [-1, 1] and [[1],[-1]] for diffing to the left and down respectively. + pos = tf.constant(np.identity(3), dtype=tf.float32) + neg = -1 * pos + filter_x = tf.expand_dims(tf.pack([neg, pos]), 0) # [-1, 1] + filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)]) # [[1],[-1]] + strides = [1, 1, 1, 1] # stride of (1, 1) + padding = 'SAME' + + gen_dx = tf.abs(tf.nn.conv2d(gen_frames[i], filter_x, strides, padding=padding)) + gen_dy = tf.abs(tf.nn.conv2d(gen_frames[i], filter_y, strides, padding=padding)) + gt_dx = tf.abs(tf.nn.conv2d(gt_frames[i], filter_x, strides, padding=padding)) + gt_dy = tf.abs(tf.nn.conv2d(gt_frames[i], filter_y, strides, padding=padding)) + + grad_diff_x = tf.abs(gt_dx - gen_dx) + grad_diff_y = tf.abs(gt_dy - gen_dy) + + scale_losses.append(tf.reduce_sum((grad_diff_x ** alpha + grad_diff_y ** alpha))) + + # condense into one tensor and avg + return tf.reduce_mean(tf.pack(scale_losses)) + + +def adv_loss(preds, labels): + """ + Calculates the sum of BCE losses between the predicted classifications and true labels. + + @param preds: The predicted classifications at each scale. + @param labels: The true labels. (Same for every scale). + + @return: The adversarial loss. 
+ """ + # calculate the loss for each scale + scale_losses = [] + for i in xrange(len(preds)): + loss = bce_loss(preds[i], labels) + scale_losses.append(loss) + + # condense into one tensor and avg + return tf.reduce_mean(tf.pack(scale_losses)) diff --git a/Code/loss_functions_test.py b/Code/loss_functions_test.py new file mode 100644 index 0000000..6b015f2 --- /dev/null +++ b/Code/loss_functions_test.py @@ -0,0 +1,304 @@ +from loss_functions import * + +sess = tf.Session() +BATCH_SIZE = 2 +NUM_SCALES = 5 +MAX_P = 5 +MAX_ALPHA = 1 + + +# noinspection PyClassHasNoInit +class TestBCELoss: + def test_false_correct(self): + targets = tf.constant(np.zeros([5, 1])) + preds = 1e-7 * tf.constant(np.ones([5, 1])) + res = sess.run(bce_loss(preds, targets)) + + log_con = np.log10(1 - 1e-7) + res_tru = -1 * np.sum(np.array([log_con] * 5)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_false_incorrect(self): + targets = tf.constant(np.zeros([5, 1])) + preds = tf.constant(np.ones([5, 1])) - 1e-7 + res = sess.run(bce_loss(preds, targets)) + + log_con = np.log10(1e-7) + res_tru = -1 * np.sum(np.array([log_con] * 5)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_false_half(self): + targets = tf.constant(np.zeros([5, 1])) + preds = 0.5 * tf.constant(np.ones([5, 1])) + res = sess.run(bce_loss(preds, targets)) + + log_con = np.log10(0.5) + res_tru = -1 * np.sum(np.array([log_con] * 5)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_true_correct(self): + targets = tf.constant(np.ones([5, 1])) + preds = tf.constant(np.ones([5, 1])) - 1e-7 + res = sess.run(bce_loss(preds, targets)) + + log = np.log10(1 - 1e-7) + res_tru = -1 * np.sum(np.array([log] * 5)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_true_incorrect(self): + targets = tf.constant(np.ones([5, 1])) + preds = 1e-7 * tf.constant(np.ones([5, 1])) + res = sess.run(bce_loss(preds, targets)) + + log = np.log10(1e-7) + res_tru = -1 * np.sum(np.array([log] * 5)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_true_half(self): + targets = tf.constant(np.ones([5, 1])) + preds = 0.5 * tf.constant(np.ones([5, 1])) + res = sess.run(bce_loss(preds, targets)) + + log = np.log10(0.5) + res_tru = -1 * np.sum(np.array([log] * 5)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + +# noinspection PyClassHasNoInit +class TestLPLoss: + def test_same_images(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + scale_preds.append(tf.constant(np.ones([BATCH_SIZE, 2**i, 2**i, 3]))) + scale_truths.append(tf.constant(np.ones([BATCH_SIZE, 2**i, 2**i, 3]))) + + for p in xrange(1, MAX_P + 1): + res = sess.run(lp_loss(scale_preds, scale_truths, p)) + assert res == res_tru, 'failed on p = %d' % p + + def test_opposite_images(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + scale_preds.append(tf.constant(np.zeros([BATCH_SIZE, 2**i, 2 ** i, 3]))) + scale_truths.append(tf.constant(np.ones([BATCH_SIZE, 2**i, 2 ** i, 3]))) + + res_tru += BATCH_SIZE * 2**i * 2**i * 3 + + for p in xrange(1, MAX_P + 1): + res = sess.run(lp_loss(scale_preds, scale_truths, p)) + assert res == res_tru, 'failed on p = %d' % p + + def test_some_correct(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + # generate batch of 
3-deep identity matrices + preds = np.empty([BATCH_SIZE, 2**i, 2**i, 3]) + imat = np.identity(2**i) + for elt in xrange(BATCH_SIZE): + preds[elt] = np.dstack([imat, imat, imat]) + + scale_preds.append(tf.constant(preds)) + scale_truths.append(tf.constant(np.zeros([BATCH_SIZE, 2**i, 2**i, 3]))) + + res_tru += BATCH_SIZE * 2**i * 3 + + for p in xrange(1, MAX_P + 1): + res = sess.run(lp_loss(scale_preds, scale_truths, p)) + assert res == res_tru, 'failed on p = %d' % p + + def test_l_high(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + # opposite images + preds = np.empty([BATCH_SIZE, 2**i, 2**i, 3]) + preds.fill(3) + scale_preds.append(tf.constant(preds)) + scale_truths.append(tf.constant(np.zeros([BATCH_SIZE, 2**i, 2**i, 3]))) + + res_tru += BATCH_SIZE * 2**i * 2**i * 3 + + for p in xrange(1, MAX_P + 1): + res = sess.run(lp_loss(scale_preds, scale_truths, p)) + assert res == res_tru * (3**p), 'failed on p = %d' % p + + +# noinspection PyClassHasNoInit +class TestGDLLoss: + def test_same_uniform(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + scale_preds.append(tf.ones([BATCH_SIZE, 2 ** i, 2 ** i, 3])) + scale_truths.append(tf.ones([BATCH_SIZE, 2 ** i, 2 ** i, 3])) + + for a in xrange(1, MAX_ALPHA + 1): + res = sess.run(gdl_loss(scale_preds, scale_truths, a)) + assert res == res_tru, 'failed on alpha = %d' % a + + def test_same_nonuniform(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + # generate batch of 3-deep identity matrices + arr = np.empty([BATCH_SIZE, 2 ** i, 2 ** i, 3]) + imat = np.identity(2 ** i) + for elt in xrange(BATCH_SIZE): + arr[elt] = np.dstack([imat, imat, imat]) + + scale_preds.append(tf.constant(arr, dtype=tf.float32)) + scale_truths.append(tf.constant(arr, dtype=tf.float32)) + + for a in xrange(1, MAX_ALPHA + 1): + res = sess.run(gdl_loss(scale_preds, scale_truths, a)) + assert res == res_tru, 'failed on alpha = %d' % a + + # TODO: Not 0 loss as expected because the 1s array is padded by 0s, so there is some gradient. 
+ def test_diff_uniform(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_tru = 0 + for i in xrange(1, NUM_SCALES + 1): + scale_preds.append(tf.zeros([BATCH_SIZE, 2 ** i, 2 ** i, 3])) + scale_truths.append(tf.ones([BATCH_SIZE, 2 ** i, 2 ** i, 3])) + + # every diff should have an abs value of 1, so no need for alpha handling + res_tru += BATCH_SIZE * 2 ** i * 2 * 3 + + for a in xrange(1, MAX_ALPHA + 1): + res = sess.run(gdl_loss(scale_preds, scale_truths, a)) + assert res == res_tru, 'failed on alpha = %d' % a + + def test_diff_one_uniform_one_not(self): + # generate scales + scale_preds = [] + scale_truths = [] + + res_trus = np.zeros(MAX_ALPHA - 1) + for i in xrange(1, NUM_SCALES + 1): + # generate batch of 3-deep matrices with 3s on the diagonals + preds = np.empty([BATCH_SIZE, 2 ** i, 2 ** i, 3]) + imat = np.identity(2 ** i) * 3 + for elt in xrange(BATCH_SIZE): + preds[elt] = np.dstack([imat, imat, imat]) + + scale_preds.append(tf.constant(preds, dtype=tf.float32)) + scale_truths.append(tf.zeros([BATCH_SIZE, 2 ** i, 2 ** i, 3])) + + # every diff has an abs value of 3, so we can multiply that, raised to alpha + # for each alpha check, times the number of diffs in a batch: + # BATCH_SIZE * (diffs to left + down) * (diffs from up and right) * (# 3s in height) * + # (# channels) + num_diffs = BATCH_SIZE * 2 * 2 * 2**i * 3 + + for a in xrange(1, MAX_ALPHA): + res_trus[a] += num_diffs * 3**a + + for a, res_tru in enumerate(res_trus): + res = sess.run(gdl_loss(scale_preds, scale_truths, a + 1)) + assert res == res_tru, 'failed on alpha = %d' % (a + 1) + + +# noinspection PyClassHasNoInit +class TestAdvLoss: + def test_false_correct(self): + # generate scales + scale_preds = [] + targets = tf.constant(np.zeros([5, 1])) + + res_tru = 0 + log_con = np.log10(1 - 1e-7) + for i in xrange(NUM_SCALES): + scale_preds.append(1e-7 * tf.constant(np.ones([5, 1]))) + res_tru += -1 * np.sum(np.array([log_con] * 5)) + + res = sess.run(adv_loss(scale_preds, targets)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_false_incorrect(self): + scale_preds = [] + targets = tf.constant(np.zeros([5, 1])) + + res_tru = 0 + log_con = np.log10(1e-7) + for i in xrange(NUM_SCALES): + scale_preds.append(tf.constant(np.ones([5, 1])) - 1e-7) + res_tru += -1 * np.sum(np.array([log_con] * 5)) + + res = sess.run(adv_loss(scale_preds, targets)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_false_half(self): + scale_preds = [] + targets = tf.constant(np.zeros([5, 1])) + + res_tru = 0 + log_con = np.log10(0.5) + for i in xrange(NUM_SCALES): + scale_preds.append(0.5 * tf.constant(np.ones([5, 1]))) + res_tru += -1 * np.sum(np.array([log_con] * 5)) + + res = sess.run(adv_loss(scale_preds, targets)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_true_correct(self): + scale_preds = [] + targets = tf.constant(np.ones([5, 1])) + + res_tru = 0 + log = np.log10(1 - 1e-7) + for i in xrange(NUM_SCALES): + scale_preds.append(tf.constant(np.ones([5, 1])) - 1e-7) + res_tru += -1 * np.sum(np.array([log] * 5)) + + res = sess.run(adv_loss(scale_preds, targets)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_true_incorrect(self): + scale_preds = [] + targets = tf.constant(np.ones([5, 1])) + + res_tru = 0 + log = np.log10(1e-7) + for i in xrange(NUM_SCALES): + scale_preds.append(1e-7 * tf.constant(np.ones([5, 1]))) + res_tru += -1 * np.sum(np.array([log] * 5)) + + res = sess.run(adv_loss(scale_preds, 
targets)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) + + def test_true_half(self): + scale_preds = [] + targets = tf.constant(np.ones([5, 1])) + + res_tru = 0 + log = np.log10(0.5) + for i in xrange(NUM_SCALES): + scale_preds.append(0.5 * tf.constant(np.ones([5, 1]))) + res_tru += -1 * np.sum(np.array([log] * 5)) + + res = sess.run(adv_loss(scale_preds, targets)) + assert np.array_equal(np.around(res, 7), np.around(res_tru, 7)) diff --git a/Code/process_data.py b/Code/process_data.py new file mode 100644 index 0000000..170959a --- /dev/null +++ b/Code/process_data.py @@ -0,0 +1,71 @@ +import numpy as np +import getopt +import sys +from glob import glob + +import constants as c +from utils import process_clip + + +def process_training_data(num_clips): + """ + Processes random training clips from the full training data. Saves to TRAIN_DIR_CLIPS by + default. + + @param num_clips: The number of clips to process. Default = 5000000 (set in __main__). + + @warning: This can take a couple of hours to complete with large numbers of clips. + """ + num_prev_clips = len(glob(c.TRAIN_DIR_CLIPS + '*')) + + for clip_num in xrange(num_prev_clips, num_clips + num_prev_clips): + clip = process_clip() + + np.savez_compressed(c.TRAIN_DIR_CLIPS + str(clip_num), clip) + + if (clip_num + 1) % 100 == 0: print 'Processed %d clips' % (clip_num + 1) + + +def usage(): + print 'Options:' + print '-n/--num_clips= <# clips to process for training>' + print '-t/--train_dir= <Directory of full training frames>' + print '-c/--clips_dir= <Save directory for processed clips>' + print " (I suggest making this a hidden dir so the filesystem doesn't freeze" + print " with so many files. DON'T `ls` THIS DIRECTORY!)" + print '-o/--overwrite (Overwrites the previous data in the training dir)' + + +def main(): + ## + # Handle command line input + ## + + num_clips = 5000000 + + try: + opts, _ = getopt.getopt(sys.argv[1:], 'n:t:c:o', + ['num_clips=', 'train_dir=', 'clips_dir=', 'overwrite']) + except getopt.GetoptError: + usage() + sys.exit(2) + + for opt, arg in opts: + if opt in ('-n', '--num_clips'): + num_clips = int(arg) + if opt in ('-t', '--train_dir'): + c.TRAIN_DIR = c.get_dir(arg) + if opt in ('-c', '--clips_dir'): + c.TRAIN_DIR_CLIPS = c.get_dir(arg) + if opt in ('-o', '--overwrite'): + c.clear_dir(c.TRAIN_DIR_CLIPS) + + ## + # Process data for training + ## + + process_training_data(num_clips) + + +if __name__ == '__main__': + main() diff --git a/Code/tfutils.py b/Code/tfutils.py new file mode 100644 index 0000000..22baf95 --- /dev/null +++ b/Code/tfutils.py @@ -0,0 +1,133 @@ +import tensorflow as tf +import numpy as np + + +def w(shape, stddev=0.01): + """ + @return A weight layer with the given shape and standard deviation. Initialized with a + truncated normal distribution. + """ + return tf.Variable(tf.truncated_normal(shape, stddev=stddev)) + + +def b(shape, const=0.1): + """ + @return A bias layer with the given shape. + """ + return tf.Variable(tf.constant(const, shape=shape)) + + +def conv_out_size(i, p, k, s): + """ + Gets the output size for a 2D convolution. (Assumes square input and kernel). + + @param i: The side length of the input. + @param p: The padding type (either 'SAME' or 'VALID'). + @param k: The side length of the kernel. + @param s: The stride. + + @type i: int + @type p: string + @type k: int + @type s: int + + @return The side length of the output. 
+ """ + # convert p to a number + if p == 'SAME': + p = k // 2 + elif p == 'VALID': + p = 0 + else: + raise ValueError('p must be "SAME" or "VALID".') + + return int(((i + (2 * p) - k) / s) + 1) + + +def log10(t): + """ + Calculates the base-10 log of each element in t. + + @param t: The tensor from which to calculate the base-10 log. + + @return: A tensor with the base-10 log of each element in t. + """ + + numerator = tf.log(t) + denominator = tf.log(tf.constant(10, dtype=numerator.dtype)) + return numerator / denominator + + +def batch_pad_to_bounding_box(images, offset_height, offset_width, target_height, target_width): + """ + Zero-pads a batch of images with the given dimensions. + + @param images: 4-D tensor with shape [batch_size, height, width, channels] + @param offset_height: Number of rows of zeros to add on top. + @param offset_width: Number of columns of zeros to add on the left. + @param target_height: Height of output images. + @param target_width: Width of output images. + + @return: The batch of images, all zero-padded with the specified dimensions. + """ + batch_size, height, width, channels = tf.Session().run(tf.shape(images)) + + if not offset_height >= 0: + raise ValueError('offset_height must be >= 0') + if not offset_width >= 0: + raise ValueError('offset_width must be >= 0') + if not target_height >= height + offset_height: + raise ValueError('target_height must be >= height + offset_height') + if not target_width >= width + offset_width: + raise ValueError('target_width must be >= width + offset_width') + + num_tpad = offset_height + num_lpad = offset_width + num_bpad = target_height - (height + offset_height) + num_rpad = target_width - (width + offset_width) + + tpad = np.zeros([batch_size, num_tpad, width, channels]) + bpad = np.zeros([batch_size, num_bpad, width, channels]) + lpad = np.zeros([batch_size, target_height, num_lpad, channels]) + rpad = np.zeros([batch_size, target_height, num_rpad, channels]) + + padded = images + if num_tpad > 0 and num_bpad > 0: padded = tf.concat(1, [tpad, padded, bpad]) + elif num_tpad > 0: padded = tf.concat(1, [tpad, padded]) + elif num_bpad > 0: padded = tf.concat(1, [padded, bpad]) + if num_lpad > 0 and num_rpad > 0: padded = tf.concat(2, [lpad, padded, rpad]) + elif num_lpad > 0: padded = tf.concat(2, [lpad, padded]) + elif num_rpad > 0: padded = tf.concat(2, [padded, rpad]) + + return padded + + +def batch_crop_to_bounding_box(images, offset_height, offset_width, target_height, target_width): + """ + Crops a batch of images to the given dimensions. + + @param images: 4-D tensor with shape [batch, height, width, channels] + @param offset_height: Vertical coordinate of the top-left corner of the result in the input. + @param offset_width: Horizontal coordinate of the top-left corner of the result in the input. + @param target_height: Height of output images. + @param target_width: Width of output images. + + @return: The batch of images, all cropped the specified dimensions. 
diff --git a/Code/tfutils_test.py b/Code/tfutils_test.py
new file mode 100644
index 0000000..4e2b490
--- /dev/null
+++ b/Code/tfutils_test.py
@@ -0,0 +1,102 @@
+from tfutils import *
+
+imgs = tf.constant(np.ones([2, 2, 2, 3]))
+sess = tf.Session()
+
+
+# noinspection PyClassHasNoInit,PyMethodMayBeStatic
+class TestPad:
+    def test_rb(self):
+        res = sess.run(batch_pad_to_bounding_box(imgs, 0, 0, 4, 4))
+        assert np.array_equal(res, np.array([[[[1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]]],
+                                             [[[1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]]]], dtype=float))
+
+    def test_center(self):
+        res = sess.run(batch_pad_to_bounding_box(imgs, 1, 1, 4, 4))
+        assert np.array_equal(res, np.array([[[[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]]],
+                                             [[[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [1, 1, 1],
+                                               [1, 1, 1],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0],
+                                               [0, 0, 0]]]], dtype=float))
+
+
+padded = batch_pad_to_bounding_box(imgs, 1, 1, 4, 4)
+
+
+# noinspection PyClassHasNoInit
+class TestCrop:
+    def test_rb(self):
+        res = sess.run(batch_crop_to_bounding_box(padded, 0, 0, 2, 2))
+        assert np.array_equal(res, np.array([[[[0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [1, 1, 1]]],
+                                             [[[0, 0, 0],
+                                               [0, 0, 0]],
+                                              [[0, 0, 0],
+                                               [1, 1, 1]]]]))
+
+    def test_center(self):
+        res = sess.run(batch_crop_to_bounding_box(padded, 1, 1, 2, 2))
+        assert np.array_equal(res, np.ones([2, 2, 2, 3]))
diff --git a/Code/utils.py b/Code/utils.py
new file mode 100644
index 0000000..2b97bdb
--- /dev/null
+++ b/Code/utils.py
@@ -0,0 +1,212 @@
+import tensorflow as tf
+import numpy as np
+from scipy.ndimage import imread
+from glob import glob
+
+import constants as c
+from tfutils import log10
+
+##
+# Data
+##
+
+def normalize_frames(frames):
+    """
+    Convert frames from uint8 [0, 255] to float32 [-1, 1].
+
+    @param frames: A numpy array. The frames to be converted.
+
+    @return: The normalized frames.
+    """
+    new_frames = frames.astype(np.float32)
+    new_frames /= (255 / 2.0)  # float division; (255 / 2) would truncate to 127 in Python 2
+    new_frames -= 1
+
+    return new_frames
+
+
+def denormalize_frames(frames):
+    """
+    Performs the inverse operation of normalize_frames.
+
+    @param frames: A numpy array. The frames to be converted.
+
+    @return: The denormalized frames.
+    """
+    new_frames = frames + 1
+    new_frames *= (255 / 2.0)  # float division, matching normalize_frames
+    # noinspection PyUnresolvedReferences
+    new_frames = new_frames.astype(np.uint8)
+
+    return new_frames
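A quick round trip through the two normalization functions (illustrative only; because denormalize_frames truncates when casting back to uint8, values can come back off by one):

    import numpy as np
    from utils import normalize_frames, denormalize_frames

    pixels = np.array([0, 128, 255], dtype=np.uint8)
    norm = normalize_frames(pixels)
    print norm                      # [-1.0, ~0.004, 1.0]
    print denormalize_frames(norm)  # [0, 128, 255], up to truncation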
+
+
+def clip_l2_diff(clip):
+    """
+    @param clip: A numpy array of shape [c.TRAIN_HEIGHT, c.TRAIN_WIDTH, (3 * (c.HIST_LEN + 1))].
+    @return: The sum of l2 differences between the frame pixels of each sequential pair of frames.
+    """
+    diff = 0
+    for i in xrange(c.HIST_LEN):
+        frame = clip[:, :, 3 * i:3 * (i + 1)]
+        next_frame = clip[:, :, 3 * (i + 1):3 * (i + 2)]
+        # noinspection PyTypeChecker
+        diff += np.sum(np.square(next_frame - frame))
+
+    return diff
+
+
+def get_full_clips(data_dir, num_clips, num_rec_out=1):
+    """
+    Loads a batch of random clips from the unprocessed train or test data.
+
+    @param data_dir: The directory of the data to read. Should be either c.TRAIN_DIR or c.TEST_DIR.
+    @param num_clips: The number of clips to read.
+    @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively,
+                        using the previously-generated frames as input. Default = 1.
+
+    @return: An array of shape
+             [num_clips, c.TEST_HEIGHT, c.TEST_WIDTH, (3 * (c.HIST_LEN + num_rec_out))].
+             A batch of frame sequences with values normalized in range [-1, 1].
+    """
+    clips = np.empty([num_clips,
+                      c.TEST_HEIGHT,
+                      c.TEST_WIDTH,
+                      (3 * (c.HIST_LEN + num_rec_out))])
+
+    # get num_clips random episodes
+    ep_dirs = np.random.choice(glob(data_dir + '*'), num_clips)
+
+    # get a random clip of length HIST_LEN + num_rec_out from each episode
+    for clip_num, ep_dir in enumerate(ep_dirs):
+        ep_frame_paths = glob(ep_dir + '/*')
+        start_index = np.random.choice(len(ep_frame_paths) - (c.HIST_LEN + num_rec_out - 1))
+        clip_frame_paths = ep_frame_paths[start_index:start_index + (c.HIST_LEN + num_rec_out)]
+
+        # read in frames
+        for frame_num, frame_path in enumerate(clip_frame_paths):
+            frame = imread(frame_path, mode='RGB')
+            norm_frame = normalize_frames(frame)
+
+            clips[clip_num, :, :, frame_num * 3:(frame_num + 1) * 3] = norm_frame
+
+    return clips
+
+
+def process_clip():
+    """
+    Gets a clip from the train dataset, cropped randomly to c.TRAIN_HEIGHT x c.TRAIN_WIDTH.
+
+    @return: An array of shape [c.TRAIN_HEIGHT, c.TRAIN_WIDTH, (3 * (c.HIST_LEN + 1))].
+             A frame sequence with values normalized in range [-1, 1].
+    """
+    clip = get_full_clips(c.TRAIN_DIR, 1)[0]
+
+    # Randomly crop the clip. With 0.05 probability, take the first crop offered; otherwise,
+    # repeat until we get a crop with enough movement in it.
+    take_first = np.random.choice(2, p=[0.95, 0.05])
+    cropped_clip = np.empty([c.TRAIN_HEIGHT, c.TRAIN_WIDTH, 3 * (c.HIST_LEN + 1)])
+    for i in xrange(100):  # cap at 100 trials in case the clip has no movement anywhere
+        crop_x = np.random.choice(c.TEST_WIDTH - c.TRAIN_WIDTH + 1)
+        crop_y = np.random.choice(c.TEST_HEIGHT - c.TRAIN_HEIGHT + 1)
+        cropped_clip = clip[crop_y:crop_y + c.TRAIN_HEIGHT, crop_x:crop_x + c.TRAIN_WIDTH, :]
+
+        if take_first or clip_l2_diff(cropped_clip) > c.MOVEMENT_THRESHOLD:
+            break
+
+    return cropped_clip
+
+
+def get_train_batch():
+    """
+    Loads c.BATCH_SIZE clips from the database of preprocessed training clips.
+
+    @return: An array of shape
+             [c.BATCH_SIZE, c.TRAIN_HEIGHT, c.TRAIN_WIDTH, (3 * (c.HIST_LEN + 1))].
+    """
+    clips = np.empty([c.BATCH_SIZE, c.TRAIN_HEIGHT, c.TRAIN_WIDTH, (3 * (c.HIST_LEN + 1))],
+                     dtype=np.float32)
+    for i in xrange(c.BATCH_SIZE):
+        path = c.TRAIN_DIR_CLIPS + str(np.random.choice(c.NUM_CLIPS)) + '.npz'
+        clip = np.load(path)['arr_0']
+
+        clips[i] = clip
+
+    return clips
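To see what clip_l2_diff measures, a toy example (the 4x4 frame size is illustrative; real clips are c.TRAIN_HEIGHT x c.TRAIN_WIDTH):

    import numpy as np
    import constants as c
    from utils import clip_l2_diff

    # Every frame is all zeros except the last, which is 0.1 everywhere, so
    # only the final pair differs: 4 * 4 * 3 * 0.1 ** 2 = 0.48.
    clip = np.zeros([4, 4, 3 * (c.HIST_LEN + 1)])
    clip[:, :, -3:] = 0.1
    print clip_l2_diff(clip)  # ~0.48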
+
+
+def get_test_batch(test_batch_size, num_rec_out=1):
+    """
+    Gets a batch of clips from the test dataset.
+
+    @param test_batch_size: The number of clips in the batch.
+    @param num_rec_out: The number of outputs to predict. Outputs > 1 are computed recursively,
+                        using the previously-generated frames as input. Default = 1.
+
+    @return: An array of shape
+             [test_batch_size, c.TEST_HEIGHT, c.TEST_WIDTH, (3 * (c.HIST_LEN + num_rec_out))].
+             A batch of frame sequences with values normalized in range [-1, 1].
+    """
+    return get_full_clips(c.TEST_DIR, test_batch_size, num_rec_out=num_rec_out)
+
+
+##
+# Error calculation
+##
+
+# TODO: Add SSIM error http://www.cns.nyu.edu/pub/eero/wang03-reprint.pdf
+# TODO: Unit test error functions.
+
+def psnr_error(gen_frames, gt_frames):
+    """
+    Computes the Peak Signal to Noise Ratio error between the generated images and the ground
+    truth images.
+
+    @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by
+                       the generator model.
+    @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames
+                      for each frame in gen_frames.
+
+    @return: A scalar tensor. The mean Peak Signal to Noise Ratio error over each frame in the
+             batch.
+    """
+    shape = tf.shape(gen_frames)
+    num_pixels = tf.to_float(shape[1] * shape[2])
+    square_diff = tf.square(gt_frames - gen_frames)
+
+    batch_errors = 10 * log10(1 / ((1 / num_pixels) * tf.reduce_sum(square_diff, [1, 2, 3])))
+    return tf.reduce_mean(batch_errors)
+
+
+def sharp_diff_error(gen_frames, gt_frames):
+    """
+    Computes the Sharpness Difference error between the generated images and the ground truth
+    images.
+
+    @param gen_frames: A tensor of shape [batch_size, height, width, 3]. The frames generated by
+                       the generator model.
+    @param gt_frames: A tensor of shape [batch_size, height, width, 3]. The ground-truth frames
+                      for each frame in gen_frames.
+
+    @return: A scalar tensor. The mean Sharpness Difference error over each frame in the batch.
+    """
+    shape = tf.shape(gen_frames)
+    num_pixels = tf.to_float(shape[1] * shape[2])
+
+    # gradient difference
+    # create filters [-1, 1] and [[1], [-1]] for diffing to the left and down respectively.
+    # TODO: Could this be simplified with one filter [[-1, 2], [0, -1]]?
+    pos = tf.constant(np.identity(3), dtype=tf.float32)
+    neg = -1 * pos
+    filter_x = tf.expand_dims(tf.pack([neg, pos]), 0)  # [-1, 1]
+    filter_y = tf.pack([tf.expand_dims(pos, 0), tf.expand_dims(neg, 0)])  # [[1], [-1]]
+    strides = [1, 1, 1, 1]  # stride of (1, 1)
+    padding = 'SAME'
+
+    gen_dx = tf.abs(tf.nn.conv2d(gen_frames, filter_x, strides, padding=padding))
+    gen_dy = tf.abs(tf.nn.conv2d(gen_frames, filter_y, strides, padding=padding))
+    gt_dx = tf.abs(tf.nn.conv2d(gt_frames, filter_x, strides, padding=padding))
+    gt_dy = tf.abs(tf.nn.conv2d(gt_frames, filter_y, strides, padding=padding))
+
+    gen_grad_sum = gen_dx + gen_dy
+    gt_grad_sum = gt_dx + gt_dy
+
+    grad_diff = tf.abs(gt_grad_sum - gen_grad_sum)
+
+    batch_errors = 10 * log10(1 / ((1 / num_pixels) * tf.reduce_sum(grad_diff, [1, 2, 3])))
+    return tf.reduce_mean(batch_errors)
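As a numeric sanity check on psnr_error (a sketch; note that num_pixels counts height * width only, so the squared error at each pixel is summed over the three color channels, and the peak signal value is taken to be 1):

    import numpy as np
    import tensorflow as tf
    from utils import psnr_error

    gt = np.zeros([1, 4, 4, 3], dtype=np.float32)
    gen = gt.copy()
    gen[:, :, :, 0] += 0.1  # off by 0.1 in one channel at every pixel

    # Summed squared error per pixel is 0.01, so the result is
    # 10 * log10(1 / 0.01) = 20 dB.
    print tf.Session().run(psnr_error(tf.constant(gen), tf.constant(gt)))  # ~20.0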
