summaryrefslogtreecommitdiff
path: root/Code/avg_runner.py
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2018-04-24 20:19:24 +0200
committerjules@lens <julescarbon@gmail.com>2018-04-24 20:19:24 +0200
commit83e91e0a8effcd20466e56b6ecc3e349bbfa5e0e (patch)
tree377977c1068fa2411ac6b0e7c6da4aa97873ae8b /Code/avg_runner.py
parent9b0d10f357871231bbec06c610363588812216e1 (diff)
updates to tensorflow code and processing additionsHEADmaster
Diffstat (limited to 'Code/avg_runner.py')
-rw-r--r--Code/avg_runner.py28
1 files changed, 25 insertions, 3 deletions
diff --git a/Code/avg_runner.py b/Code/avg_runner.py
index 6809187..ed72b63 100644
--- a/Code/avg_runner.py
+++ b/Code/avg_runner.py
@@ -3,7 +3,11 @@ import getopt
import sys
import os
-from utils import get_train_batch, get_test_batch
+import pprint
+pp = pprint.PrettyPrinter(indent=2)
+
+from glob import glob
+from utils import get_train_batch, get_test_batch, get_all_clips
import constants as c
from g_model import GeneratorModel
from d_model import DiscriminatorModel
@@ -26,8 +30,11 @@ class AVGRunner:
self.num_steps = num_steps
self.num_test_rec = num_test_rec
- self.sess = tf.Session()
- self.summary_writer = tf.train.SummaryWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
+ config = tf.ConfigProto()
+ config.gpu_options.allow_growth = True
+ self.sess = tf.Session(config = config)
+ #self.sess = tf.Session()
+ self.summary_writer = tf.summary.FileWriter(c.SUMMARY_SAVE_DIR, graph=self.sess.graph)
if c.ADVERSARIAL:
print 'Init discriminator...'
@@ -97,11 +104,20 @@ class AVGRunner:
self.g_model.test_batch(
batch, self.global_step, num_rec_out=self.num_test_rec)
+ def process(self):
+ """
+ Process a directory of images using the generator network.
+ """
+ batch = get_all_clips(c.PROCESS_DIR)
+ self.g_model.test_batch(
+ batch, self.global_step, num_rec_out=self.num_test_rec, process_only=True)
+
def usage():
print 'Options:'
print '-l/--load_path= <Relative/path/to/saved/model>'
print '-t/--test_dir= <Directory of test images>'
+ print '-p/--process_dir= <Directory to process>'
print '-r/--recursions= <# recursive predictions to make on test>'
print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)'
print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>'
@@ -123,6 +139,7 @@ def main():
load_path = None
test_only = False
+ process_only = False
num_test_rec = 1 # number of recursive predictions to make on test
num_steps = 1000001
try:
@@ -140,6 +157,9 @@ def main():
load_path = arg
if opt in ('-t', '--test_dir'):
c.set_test_dir(arg)
+ if opt in ('-p', '--process_dir'):
+ c.set_process_dir(arg)
+ process_only = True
if opt in ('-r', '--recursions'):
num_test_rec = int(arg)
if opt in ('-a', '--adversarial'):
@@ -177,6 +197,8 @@ def main():
runner = AVGRunner(num_steps, load_path, num_test_rec)
if test_only:
runner.test()
+ elif process_only:
+ runner.process()
else:
runner.train()