diff options
Diffstat (limited to 'Code/avg_runner.py')
| -rw-r--r-- | Code/avg_runner.py | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/Code/avg_runner.py b/Code/avg_runner.py index 6809187..ed72b63 100644 --- a/Code/avg_runner.py +++ b/Code/avg_runner.py @@ -3,7 +3,11 @@ import getopt import sys import os -from utils import get_train_batch, get_test_batch +import pprint +pp = pprint.PrettyPrinter(indent=2) + +from glob import glob +from utils import get_train_batch, get_test_batch, get_all_clips import constants as c from g_model import GeneratorModel from d_model import DiscriminatorModel @@ -26,8 +30,11 @@ class AVGRunner: self.num_steps = num_steps self.num_test_rec = num_test_rec - self.sess = tf.Session() - self.summary_writer = tf.train.SummaryWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph) + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + self.sess = tf.Session(config = config) + #self.sess = tf.Session() + self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph) if c.ADVERSARIAL: print 'Init discriminator...' @@ -97,11 +104,20 @@ class AVGRunner: self.g_model.test_batch( batch, self.global_step, num_rec_out=self.num_test_rec) + def process(self): + """ + Process a directory of images using the generator network. + """ + batch = get_all_clips(c.PROCESS_DIR) + self.g_model.test_batch( + batch, self.global_step, num_rec_out=self.num_test_rec, process_only=True) + def usage(): print 'Options:' print '-l/--load_path= <Relative/path/to/saved/model>' print '-t/--test_dir= <Directory of test images>' + print '-p/--process_dir= <Directory to process>' print '-r/--recursions= <# recursive predictions to make on test>' print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)' print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>' @@ -123,6 +139,7 @@ def main(): load_path = None test_only = False + process_only = False num_test_rec = 1 # number of recursive predictions to make on test num_steps = 1000001 try: @@ -140,6 +157,9 @@ def main(): load_path = arg if opt in ('-t', '--test_dir'): c.set_test_dir(arg) + if opt in ('-p', '--process_dir'): + c.set_process_dir(arg) + process_only = True if opt in ('-r', '--recursions'): num_test_rec = int(arg) if opt in ('-a', '--adversarial'): @@ -177,6 +197,8 @@ def main(): runner = AVGRunner(num_steps, load_path, num_test_rec) if test_only: runner.test() + elif process_only: + runner.process() else: runner.train() |
