diff options
Diffstat (limited to 'Code/g_model.py')
| -rw-r--r-- | Code/g_model.py | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/Code/g_model.py b/Code/g_model.py index 5dc8265..8ff75a1 100644 --- a/Code/g_model.py +++ b/Code/g_model.py @@ -114,7 +114,7 @@ class GeneratorModel: if scale_num > 0: last_gen_frames = tf.image.resize_images( last_gen_frames,[scale_height, scale_width]) - inputs = tf.concat(3, [inputs, last_gen_frames]) + inputs = tf.concat(axis=3, values=[inputs, last_gen_frames]) # generated frame predictions preds = inputs @@ -196,7 +196,7 @@ class GeneratorModel: name='train_op') # train loss summary - loss_summary = tf.scalar_summary('train_loss_G', self.global_loss) + loss_summary = tf.summary.scalar('train_loss_G', self.global_loss) self.summaries_train.append(loss_summary) ## @@ -215,22 +215,22 @@ class GeneratorModel: self.sharpdiff_error_test = sharp_diff_error(self.scale_preds_test[-1], self.gt_frames_test) # train error summaries - summary_psnr_train = tf.scalar_summary('train_PSNR', + summary_psnr_train = tf.summary.scalar('train_PSNR', self.psnr_error_train) - summary_sharpdiff_train = tf.scalar_summary('train_SharpDiff', + summary_sharpdiff_train = tf.summary.scalar('train_SharpDiff', self.sharpdiff_error_train) self.summaries_train += [summary_psnr_train, summary_sharpdiff_train] # test error - summary_psnr_test = tf.scalar_summary('test_PSNR', + summary_psnr_test = tf.summary.scalar('test_PSNR', self.psnr_error_test) - summary_sharpdiff_test = tf.scalar_summary('test_SharpDiff', + summary_sharpdiff_test = tf.summary.scalar('test_SharpDiff', self.sharpdiff_error_test) self.summaries_test += [summary_psnr_test, summary_sharpdiff_test] # add summaries to visualize in TensorBoard - self.summaries_train = tf.merge_summary(self.summaries_train) - self.summaries_test = tf.merge_summary(self.summaries_test) + self.summaries_train = tf.summary.merge(self.summaries_train) + self.summaries_test = tf.summary.merge(self.summaries_test) def train_step(self, batch, discriminator=None): """ @@ -307,13 +307,17 @@ class GeneratorModel: scale_width = int(self.width_train * scale_factor) # resize gt_output_frames for scale and append to scale_gts_train + broken = 0 scaled_gt_frames = np.empty([c.BATCH_SIZE, scale_height, scale_width, 3]) for i, img in enumerate(gt_frames): # for skimage.transform.resize, images need to be in range [0, 1], so normalize # to [0, 1] before resize and back to [-1, 1] after sknorm_img = (img / 2) + 0.5 - resized_frame = resize(sknorm_img, [scale_height, scale_width, 3]) - scaled_gt_frames[i] = (resized_frame - 0.5) * 2 + try: + resized_frame = resize(sknorm_img, [scale_height, scale_width, 3]) + scaled_gt_frames[i-broken] = (resized_frame - 0.5) * 2 + except: + broken += 1 scale_gts.append(scaled_gt_frames) # for every clip in the batch, save the inputs, scale preds and scale gts @@ -342,7 +346,7 @@ class GeneratorModel: return global_step - def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True): + def test_batch(self, batch, global_step, num_rec_out=1, save_imgs=True, process_only=False): """ Runs a training step using the global loss on each of the scale networks. @@ -408,7 +412,14 @@ class GeneratorModel: # Save images ## - if save_imgs: + if process_only: + pred_dir = c.get_dir(os.path.join( + c.IMG_SAVE_DIR, 'Process/Step_' + str(global_step))) + for rec_num in xrange(num_rec_out): + gen_img = rec_preds[rec_num][0] + imsave(os.path.join(pred_dir, 'gen_{:05d}.png'.format(rec_num)), gen_img) + + elif save_imgs: for pred_num in xrange(len(input_frames)): pred_dir = c.get_dir(os.path.join( c.IMG_SAVE_DIR, 'Tests/Step_' + str(global_step), str(pred_num))) |
