Diffstat (limited to 'Code/d_model.py')
-rw-r--r--  Code/d_model.py  187
1 file changed, 187 insertions, 0 deletions
diff --git a/Code/d_model.py b/Code/d_model.py
new file mode 100644
index 0000000..7b1cb12
--- /dev/null
+++ b/Code/d_model.py
@@ -0,0 +1,187 @@
+import tensorflow as tf
+import numpy as np
+from skimage.transform import resize
+
+from d_scale_model import DScaleModel
+from loss_functions import adv_loss
+import constants as c
+
+
+# noinspection PyShadowingNames
+class DiscriminatorModel:
+ def __init__(self, session, summary_writer, height, width, scale_conv_layer_fms,
+ scale_kernel_sizes, scale_fc_layer_sizes):
+ """
+        Initializes a DiscriminatorModel.
+
+ @param session: The TensorFlow session.
+        @param summary_writer: The writer object used to record TensorBoard summaries.
+ @param height: The height of the input images.
+ @param width: The width of the input images.
+ @param scale_conv_layer_fms: The number of feature maps in each convolutional layer of each
+ scale network.
+ @param scale_kernel_sizes: The size of the kernel for each layer of each scale network.
+ @param scale_fc_layer_sizes: The number of nodes in each fully-connected layer of each scale
+ network.
+
+ @type session: tf.Session
+ @type summary_writer: tf.train.SummaryWriter
+ @type height: int
+ @type width: int
+ @type scale_conv_layer_fms: list<list<int>>
+ @type scale_kernel_sizes: list<list<int>>
+ @type scale_fc_layer_sizes: list<list<int>>
+ """
+ self.sess = session
+ self.summary_writer = summary_writer
+ self.height = height
+ self.width = width
+ self.scale_conv_layer_fms = scale_conv_layer_fms
+ self.scale_kernel_sizes = scale_kernel_sizes
+ self.scale_fc_layer_sizes = scale_fc_layer_sizes
+ self.num_scale_nets = len(scale_conv_layer_fms)
+
+ self.define_graph()
+
+ # noinspection PyAttributeOutsideInit
+ def define_graph(self):
+ """
+ Sets up the model graph in TensorFlow.
+ """
+ with tf.name_scope('discriminator'):
+ ##
+ # Setup scale networks. Each will make the predictions for images at a given scale.
+ ##
+
+ self.scale_nets = []
+ for scale_num in xrange(self.num_scale_nets):
+ with tf.name_scope('scale_net_' + str(scale_num)):
+ scale_factor = 1. / 2 ** ((self.num_scale_nets - 1) - scale_num)
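+                    # e.g. with four scale networks the factors are 1/8, 1/4,
+                    # 1/2 and 1, so the final scale sees the full resolution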
+ self.scale_nets.append(DScaleModel(scale_num,
+ int(self.height * scale_factor),
+ int(self.width * scale_factor),
+ self.scale_conv_layer_fms[scale_num],
+ self.scale_kernel_sizes[scale_num],
+ self.scale_fc_layer_sizes[scale_num]))
+
+ # A list of the prediction tensors for each scale network
+ self.scale_preds = []
+ for scale_num in xrange(self.num_scale_nets):
+ self.scale_preds.append(self.scale_nets[scale_num].preds)
+
+ ##
+ # Data
+ ##
+
+ self.labels = tf.placeholder(tf.float32, shape=[None, 1], name='labels')
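+            # each label is 1 for a ground-truth frame and 0 for a generated
+            # frame, as constructed in build_feed_dict below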
+
+ ##
+ # Training
+ ##
+
+ with tf.name_scope('training'):
+ # global loss is the combined loss from every scale network
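+                # adv_loss is defined in loss_functions.py; presumably a
+                # cross-entropy between each scale's predictions and the
+                # shared real/fake labels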
+ self.global_loss = adv_loss(self.scale_preds, self.labels)
+ self.global_step = tf.Variable(0, trainable=False, name='global_step')
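+                # c.LRATE_D is the discriminator learning rate, one of the
+                # hyperparameters collected in constants.py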
+ self.optimizer = tf.train.GradientDescentOptimizer(c.LRATE_D, name='optimizer')
+ self.train_op = self.optimizer.minimize(self.global_loss,
+ global_step=self.global_step,
+ name='train_op')
+
+ # add summaries to visualize in TensorBoard
+ loss_summary = tf.scalar_summary('loss_D', self.global_loss)
+ self.summaries = tf.merge_summary([loss_summary])
+
+ def build_feed_dict(self, input_frames, gt_output_frames, generator):
+ """
+ Builds a feed_dict with resized inputs and outputs for each scale network.
+
+        @param input_frames: An array of shape
+                             [batch_size x self.height x self.width x (3 * HIST_LEN)].
+                             The frames to use for generation.
+        @param gt_output_frames: An array of shape
+                                 [batch_size x self.height x self.width x 3]. The
+                                 ground truth outputs for each sequence in input_frames.
+ @param generator: The generator model.
+
+        @return: The feed_dict needed to run this network and all of its
+                 scale_nets, built from the generator's predictions.
+ """
+ feed_dict = {}
+ batch_size = np.shape(gt_output_frames)[0]
+
+ ##
+ # Get generated frames from GeneratorModel
+ ##
+
+ g_feed_dict = {generator.input_frames_train: input_frames,
+ generator.gt_frames_train: gt_output_frames}
+ g_scale_preds = self.sess.run(generator.scale_preds_train, feed_dict=g_feed_dict)
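+        # g_scale_preds[i] holds the generated frames for scale i; their
+        # heights and widths should already match the corresponding scale
+        # network, since they are concatenated with the resized ground-truth
+        # frames below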
+
+ ##
+ # Create discriminator feed dict
+ ##
+ for scale_num in xrange(self.num_scale_nets):
+ scale_net = self.scale_nets[scale_num]
+
+ # resize gt_output_frames
+ scaled_gt_output_frames = np.empty([batch_size, scale_net.height, scale_net.width, 3])
+ for i, img in enumerate(gt_output_frames):
+ # for skimage.transform.resize, images need to be in range [0, 1], so normalize to
+ # [0, 1] before resize and back to [-1, 1] after
+ sknorm_img = (img / 2) + 0.5
+ resized_frame = resize(sknorm_img, [scale_net.height, scale_net.width, 3])
+ scaled_gt_output_frames[i] = (resized_frame - 0.5) * 2
+
+            # combine the generated frames with the resized gt_output_frames to
+            # form the full input batch for this scale network
+ scaled_input_frames = np.concatenate([g_scale_preds[scale_num],
+ scaled_gt_output_frames])
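+            # note the ordering: generated frames first, ground-truth frames
+            # second. This must match the label ordering built below (0s for
+            # generated frames, then 1s for real frames)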
+
+            # add the combined inputs to the feed_dict
+ feed_dict[scale_net.input_frames] = scaled_input_frames
+
+ # add labels for each image to feed_dict
+ feed_dict[self.labels] = np.concatenate([np.zeros([batch_size, 1]),
+ np.ones([batch_size, 1])])
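+        # each scale input holds 2 * batch_size images (generated + real), so
+        # 2 * batch_size labels are needed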
+
+ return feed_dict
+
+ def train_step(self, batch, generator):
+ """
+ Runs a training step using the global loss on each of the scale networks.
+
+ @param batch: An array of shape
+ [BATCH_SIZE x self.height x self.width x (3 * (HIST_LEN + 1))]. The input
+ and output frames, concatenated along the channel axis (index 3).
+ @param generator: The generator model.
+
+ @return: The global step.
+ """
+ ##
+ # Split into inputs and outputs
+ ##
+
+ input_frames = batch[:, :, :, :-3]
+ gt_output_frames = batch[:, :, :, -3:]
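+        # the first 3 * HIST_LEN channels hold the input history frames; the
+        # last 3 channels hold the ground-truth next frame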
+
+ ##
+ # Train
+ ##
+
+ feed_dict = self.build_feed_dict(input_frames, gt_output_frames, generator)
+
+ _, global_loss, global_step, summaries = self.sess.run(
+ [self.train_op, self.global_loss, self.global_step, self.summaries],
+ feed_dict=feed_dict)
+
+ ##
+ # User output
+ ##
+
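+        # c.STATS_FREQ and c.SUMMARY_FREQ (defined in constants.py) set how
+        # many training steps pass between console stats and saved summaries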
+ if global_step % c.STATS_FREQ == 0:
+ print 'DiscriminatorModel: step %d | global loss: %f' % (global_step, global_loss)
+ if global_step % c.SUMMARY_FREQ == 0:
+ print 'DiscriminatorModel: saved summaries'
+ self.summary_writer.add_summary(summaries, global_step)
+
+ return global_step