Diffstat (limited to 'Code/g_model.py')
-rw-r--r--  Code/g_model.py  35
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/Code/g_model.py b/Code/g_model.py
index 5dc8265..8ff75a1 100644
--- a/Code/g_model.py
+++ b/Code/g_model.py
@@ -114,7 +114,7 @@ class GeneratorModel:
if scale_num > 0:
last_gen_frames = tf.image.resize_images(
last_gen_frames,[scale_height, scale_width])
- inputs = tf.concat(3, [inputs, last_gen_frames])
+ inputs = tf.concat(axis=3, values=[inputs, last_gen_frames])
# generated frame predictions
preds = inputs
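
The hunk above moves tf.concat to the TensorFlow 1.x argument order, where the tensors come first and the concatenation axis is passed as a keyword. A minimal sketch of the updated call, assuming TensorFlow 1.x; the tensor shapes are placeholders chosen only for illustration:

    import tensorflow as tf

    inputs = tf.zeros([8, 32, 32, 3])            # stand-in for the scale's input frames
    last_gen_frames = tf.zeros([8, 32, 32, 3])   # stand-in for the previous scale's output
    # pre-1.0 API: tf.concat(3, [inputs, last_gen_frames])
    # 1.x API: values first, axis passed explicitly as a keyword
    merged = tf.concat(axis=3, values=[inputs, last_gen_frames])  # shape [8, 32, 32, 6]
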
@@ -196,7 +196,7 @@ class GeneratorModel:
name='train_op')
# train loss summary
- loss_summary = tf.scalar_summary('train_loss_G', self.global_loss)
+ loss_summary = tf.summary.scalar('train_loss_G', self.global_loss)
self.summaries_train.append(loss_summary)
##
@@ -215,22 +215,22 @@ class GeneratorModel:
self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1],
self.gt_frames_test)
# train error summaries
- summary_psnr_train = tf.scalar_summary('train_PSNR',
+ summary_psnr_train = tf.summary.scalar('train_PSNR',
self.psnr_error_train)
- summary_sharpdiff_train = tf.scalar_summary('train_SharpDiff',
+ summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff',
self.sharpdiff_error_train)
self.summaries_train += [summary_psnr_train, summary_sharpdiff_train]
# test error
- summary_psnr_test = tf.scalar_summary('test_PSNR',
+ summary_psnr_test = tf.summary.scalar('test_PSNR',
self.psnr_error_test)
- summary_sharpdiff_test = tf.scalar_summary('test_SharpDiff',
+ summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff',
self.sharpdiff_error_test)
self.summaries_test += [summary_psnr_test, summary_sharpdiff_test]
# add summaries to visualize in TensorBoard
- self.summaries_train = tf.merge_summary(self.summaries_train)
- self.summaries_test = tf.merge_summary(self.summaries_test)
+ self.summaries_train = tf.summary.merge(self.summaries_train)
+ self.summaries_test = tf.summary.merge(self.summaries_test)
def train_step(self, batch, discriminator=None):
"""
@@ -307,13 +307,17 @@ class GeneratorModel:
scale_width = int(self.width_train * scale_factor)
# resize gt_output_frames for scale and append to scale_gts_train
+ broken = 0
scaled_gt_frames = np.empty([c.BATCH_SIZE, scale_height, scale_width, 3])
for i, img in enumerate(gt_frames):
# for skimage.transform.resize, images need to be in range [0, 1], so normalize
# to [0, 1] before resize and back to [-1, 1] after
sknorm_img = (img / 2) + 0.5
- resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
- scaled_gt_frames[i] = (resized_frame - 0.5) * 2
+ try:
+ resized_frame = resize(sknorm_img, [scale_height, scale_width, 3])
+ scaled_gt_frames[i-broken] = (resized_frame - 0.5) * 2
+ except:
+ broken += 1
scale_gts.append(scaled_gt_frames)
# for every clip in the batch, save the inputs, scale preds and scale gts
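
The hunk above skips frames that skimage's resize rejects, shifting the write index down by the number of broken frames, but the preallocated scaled_gt_frames array keeps c.BATCH_SIZE rows, so trailing rows stay uninitialized whenever frames are dropped. A possible alternative sketch, not the patch's code, that collects only the successful resizes and narrows the bare except; catching ValueError is an assumption about what resize raises for malformed input:

    import numpy as np
    from skimage.transform import resize

    def scale_gt_batch(gt_frames, scale_height, scale_width):
        scaled = []
        for img in gt_frames:
            sknorm_img = (img / 2) + 0.5           # map [-1, 1] -> [0, 1] for skimage
            try:
                resized = resize(sknorm_img, [scale_height, scale_width, 3])
            except ValueError:                     # assumed failure mode; skip the frame
                continue
            scaled.append((resized - 0.5) * 2)     # back to [-1, 1]
        return np.array(scaled)                    # only the frames that resized cleanly
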
@@ -342,7 +346,7 @@ class GeneratorModel:
return global_step
- def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True):
+ def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True, process_only=False):
"""
Runs a training step using the global loss on each of the scale networks.
@@ -408,7 +412,14 @@ class GeneratorModel:
# Save images
##
- if save_imgs:
+ if process_only:
+ pred_dir = c.get_dir(os.path.join(
+ c.IMG_SAVE_DIR, 'Process/Step_' + str(global_step)))
+ for rec_num in xrange(num_rec_out):
+ gen_img = rec_preds[rec_num][0]
+ imsave(os.path.join(pred_dir, 'gen_{:05d}.png'.format(rec_num)), gen_img)
+
+ elif save_imgs:
for pred_num in xrange(len(input_frames)):
pred_dir = c.get_dir(os.path.join(
c.IMG_SAVE_DIR, 'Tests/Step_' + str(global_step), str(pred_num)))
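
The new process_only branch writes just the recursive predictions for the first clip of the batch into Process/Step_<global_step>/gen_00000.png, gen_00001.png, and so on, instead of the full per-clip input/ground-truth/prediction dump. A hedged usage sketch; model and batch are placeholders standing in for the project's own GeneratorModel instance and test-batch loader:

    # model: an already-built GeneratorModel; batch: a test batch in the project's format
    model.test_batch(batch,
                     global_step=10000,
                     num_rec_out=8,        # number of recursively generated frames to save
                     process_only=True)    # save only the Process/Step_* prediction images
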