summaryrefslogtreecommitdiff
path: root/Code
diff options
context:
space:
mode:
authorMatt Cooper <matthew_cooper@brown.edu>2016-10-09 18:23:15 -0500
committerMatt Cooper <matthew_cooper@brown.edu>2016-10-09 18:23:15 -0500
commit9c6b43a967227f7bea591f1647b8dff9953cd7ad (patch)
treeb74762c8076be0870590f036e55e53fc65ce3f61 /Code
parentbc26ec76145ad70eb2ac9364f52695a26a9bc68a (diff)
added option for stopping training after a certain number of iterations
Diffstat (limited to 'Code')
-rw-r--r--Code/avg_runner.py19
1 files changed, 13 insertions, 6 deletions
diff --git a/Code/avg_runner.py b/Code/avg_runner.py
index 9a27426..959ff90 100644
--- a/Code/avg_runner.py
+++ b/Code/avg_runner.py
@@ -10,10 +10,11 @@ from d_model import DiscriminatorModel
class AVGRunner:
- def __init__(self, model_load_path, num_test_rec):
+ def __init__(self, num_steps, model_load_path, num_test_rec):
"""
Initializes the Adversarial Video Generation Runner.
+ @param num_steps: The number of training steps to run.
@param model_load_path: The path from which to load a previously-saved model.
Default = None.
@param num_test_rec: The number of recursive generations to produce when testing. Recursive
@@ -22,6 +23,7 @@ class AVGRunner:
"""
self.global_step = 0
+ self.num_steps = num_steps
self.num_test_rec = num_test_rec
self.sess = tf.Session()
@@ -60,7 +62,7 @@ class AVGRunner:
"""
Runs a training loop on the model networks.
"""
- while True:
+ for i in xrange(self.num_steps):
if c.ADVERSARIAL:
# update discriminator
batch = get_train_batch()
@@ -103,6 +105,7 @@ def usage():
print '-r/--recursions= <# recursive predictions to make on test>'
print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)'
print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>'
+ print '-s/--steps= <Number of training steps to run> (Default=1000001)'
print '-O/--overwrite (Overwrites all previous data for the model with this save name)'
print '-T/--test_only (Only runs a test step -- no training)'
print '-H/--help (Prints usage)'
@@ -121,11 +124,13 @@ def main():
load_path = None
test_only = False
num_test_rec = 1 # number of recursive predictions to make on test
+ num_steps = 1000001
try:
- opts, _ = getopt.getopt(sys.argv[1:], 'l:t:r:a:n:OTH',
+ opts, _ = getopt.getopt(sys.argv[1:], 'l:t:r:a:n:s:OTH',
['load_path=', 'test_dir=', 'recursions=', 'adversarial=', 'name=',
- 'overwrite', 'test_only', 'help', 'stats_freq=', 'summary_freq=',
- 'img_save_freq=', 'test_freq=', 'model_save_freq='])
+ 'steps=', 'overwrite', 'test_only', 'help', 'stats_freq=',
+ 'summary_freq=', 'img_save_freq=', 'test_freq=',
+ 'model_save_freq='])
except getopt.GetoptError:
usage()
sys.exit(2)
@@ -141,6 +146,8 @@ def main():
c.ADVERSARIAL = (arg.lower() == 'true' or arg.lower() == 't')
if opt in ('-n', '--name'):
c.set_save_name(arg)
+ if opt in ('-s', '--steps'):
+ num_steps = int(arg)
if opt in ('-O', '--overwrite'):
c.clear_save_name()
if opt in ('-H', '--help'):
@@ -167,7 +174,7 @@ def main():
# Init and run the predictor
##
- runner = AVGRunner(load_path, num_test_rec)
+ runner = AVGRunner(num_steps, load_path, num_test_rec)
if test_only:
runner.test()
else: