From 566bd785e607b41d37120a20f236740944f00e84 Mon Sep 17 00:00:00 2001 From: jules on spawn Date: Wed, 16 May 2018 04:38:31 +0200 Subject: fix trainer --- train.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 974d8d8..23db492 100644 --- a/train.py +++ b/train.py @@ -51,7 +51,7 @@ default_params = { 'sample_length': 80000, 'loss_smoothing': 0.99, 'cuda': True, - 'comet_key': None + 'comet_key': None, 'primer': '', } @@ -219,7 +219,7 @@ def main(exp, frame_sizes, dataset, **params): trainer.register_plugin(GeneratorPlugin( os.path.join(results_path, 'samples'), params['n_samples'], params['sample_length'], params['sample_rate'], - params['primer'] + params['primer'], 0, 0, 0 )) trainer.register_plugin( Logger([ @@ -359,6 +359,15 @@ if __name__ == '__main__': parser.add_argument( '--primer', help='prime the generator...' ) + parser.add_argument( + '--primer_a', help='prime the generator...' + ) + parser.add_argument( + '--primer_b', help='prime the generator...' + ) + parser.add_argument( + '--recursive', help='prime the generator...' + ) parser.set_defaults(**default_params) -- cgit v1.2.3-70-g09d2