diff options
| author | Matt Cooper <matthew_cooper@brown.edu> | 2016-10-09 18:23:15 -0500 |
|---|---|---|
| committer | Matt Cooper <matthew_cooper@brown.edu> | 2016-10-09 18:23:15 -0500 |
| commit | 9c6b43a967227f7bea591f1647b8dff9953cd7ad (patch) | |
| tree | b74762c8076be0870590f036e55e53fc65ce3f61 /Code/avg_runner.py | |
| parent | bc26ec76145ad70eb2ac9364f52695a26a9bc68a (diff) | |
added option for stopping training after a certain number of iterations
Diffstat (limited to 'Code/avg_runner.py')
| -rw-r--r-- | Code/avg_runner.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/Code/avg_runner.py b/Code/avg_runner.py index 9a27426..959ff90 100644 --- a/Code/avg_runner.py +++ b/Code/avg_runner.py @@ -10,10 +10,11 @@ from d_model import DiscriminatorModel class AVGRunner: - def __init__(self, model_load_path, num_test_rec): + def __init__(self, num_steps, model_load_path, num_test_rec): """ Initializes the Adversarial Video Generation Runner. + @param num_steps: The number of training steps to run. @param model_load_path: The path from which to load a previously-saved model. Default = None. @param num_test_rec: The number of recursive generations to produce when testing. Recursive @@ -22,6 +23,7 @@ class AVGRunner: """ self.global_step = 0 + self.num_steps = num_steps self.num_test_rec = num_test_rec self.sess = tf.Session() @@ -60,7 +62,7 @@ class AVGRunner: """ Runs a training loop on the model networks. """ - while True: + for i in xrange(self.num_steps): if c.ADVERSARIAL: # update discriminator batch = get_train_batch() @@ -103,6 +105,7 @@ def usage(): print '-r/--recursions= <# recursive predictions to make on test>' print '-a/--adversarial= <{t/f}> (Whether to use adversarial training. Default=True)' print '-n/--name= <Subdirectory of ../Data/Save/*/ in which to save output of this run>' + print '-s/--steps= <Number of training steps to run> (Default=1000001)' print '-O/--overwrite (Overwrites all previous data for the model with this save name)' print '-T/--test_only (Only runs a test step -- no training)' print '-H/--help (Prints usage)' @@ -121,11 +124,13 @@ def main(): load_path = None test_only = False num_test_rec = 1 # number of recursive predictions to make on test + num_steps = 1000001 try: - opts, _ = getopt.getopt(sys.argv[1:], 'l:t:r:a:n:OTH', + opts, _ = getopt.getopt(sys.argv[1:], 'l:t:r:a:n:s:OTH', ['load_path=', 'test_dir=', 'recursions=', 'adversarial=', 'name=', - 'overwrite', 'test_only', 'help', 'stats_freq=', 'summary_freq=', - 'img_save_freq=', 'test_freq=', 'model_save_freq=']) + 'steps=', 'overwrite', 'test_only', 'help', 'stats_freq=', + 'summary_freq=', 'img_save_freq=', 'test_freq=', + 'model_save_freq=']) except getopt.GetoptError: usage() sys.exit(2) @@ -141,6 +146,8 @@ def main(): c.ADVERSARIAL = (arg.lower() == 'true' or arg.lower() == 't') if opt in ('-n', '--name'): c.set_save_name(arg) + if opt in ('-s', '--steps'): + num_steps = int(arg) if opt in ('-O', '--overwrite'): c.clear_save_name() if opt in ('-H', '--help'): @@ -167,7 +174,7 @@ def main(): # Init and run the predictor ## - runner = AVGRunner(load_path, num_test_rec) + runner = AVGRunner(num_steps, load_path, num_test_rec) if test_only: runner.test() else: |
