From 566bd785e607b41d37120a20f236740944f00e84 Mon Sep 17 00:00:00 2001 From: jules on spawn Date: Wed, 16 May 2018 04:38:31 +0200 Subject: fix trainer --- train.py | 13 +++++++++++-- trainer/plugins.py | 5 ++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 974d8d8..23db492 100644 --- a/train.py +++ b/train.py @@ -51,7 +51,7 @@ default_params = { 'sample_length': 80000, 'loss_smoothing': 0.99, 'cuda': True, - 'comet_key': None + 'comet_key': None, 'primer': '', } @@ -219,7 +219,7 @@ def main(exp, frame_sizes, dataset, **params): trainer.register_plugin(GeneratorPlugin( os.path.join(results_path, 'samples'), params['n_samples'], params['sample_length'], params['sample_rate'], - params['primer'] + params['primer'], 0, 0, 0 )) trainer.register_plugin( Logger([ @@ -359,6 +359,15 @@ if __name__ == '__main__': parser.add_argument( '--primer', help='prime the generator...' ) + parser.add_argument( + '--primer_a', help='prime the generator...' + ) + parser.add_argument( + '--primer_b', help='prime the generator...' + ) + parser.add_argument( + '--recursive', help='prime the generator...' + ) parser.set_defaults(**default_params) diff --git a/trainer/plugins.py b/trainer/plugins.py index 132c33d..5db48ab 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -161,7 +161,10 @@ class GeneratorPlugin(Plugin): self.generate = PrimedGenerator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): - samples = self.generate(self.n_samples, self.sample_length, self.primer, self.primer_a, self.primer_b, self.recursive) \ + if self.primer == "": + samples = self.generate(self.n_samples, self.sample_length).cpu().float().numpy() + else: + samples = self.generate(self.n_samples, self.sample_length, self.primer, self.primer_a, self.primer_b, self.recursive) \ .cpu().float().numpy() for i in range(self.n_samples): write_wav( -- cgit v1.2.3-70-g09d2