From 88eecd62b75cc032752aa10121d376cc7bca418b Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Mon, 14 May 2018 19:20:31 +0200 Subject: flag to prime the generator --- .gitignore | 2 ++ generate.py | 9 +++++++-- train.py | 6 +++++- trainer/plugins.py | 5 +++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 26a1d76..2c235f0 100644 --- a/.gitignore +++ b/.gitignore @@ -113,3 +113,5 @@ results/old_checkpoints/ mgb run_* +.DS_Store + diff --git a/generate.py b/generate.py index 92a930f..fb63b09 100644 --- a/generate.py +++ b/generate.py @@ -51,7 +51,8 @@ default_params = { 'sample_length': 80000, 'loss_smoothing': 0.99, 'cuda': True, - 'comet_key': None + 'comet_key': None, + 'primer': 'zero' } tag_params = [ @@ -222,7 +223,8 @@ def main(exp, frame_sizes, dataset, **params): """ trainer.register_plugin(GeneratorPlugin( os.path.join(results_path, 'samples'), params['n_samples'], - params['sample_length'], params['sample_rate'] + params['sample_length'], params['sample_rate'], + params['primer'] )) """ trainer.register_plugin( @@ -359,6 +361,9 @@ if __name__ == '__main__': parser.add_argument( '--comet_key', help='comet.ml API key' ) + parser.add_argument( + '--primer', help='prime the generator...' + ) parser.set_defaults(**default_params) diff --git a/train.py b/train.py index 1f4ab3b..a40e5f6 100644 --- a/train.py +++ b/train.py @@ -218,7 +218,8 @@ def main(exp, frame_sizes, dataset, **params): )) trainer.register_plugin(GeneratorPlugin( os.path.join(results_path, 'samples'), params['n_samples'], - params['sample_length'], params['sample_rate'] + params['sample_length'], params['sample_rate'], + params['primer'] )) trainer.register_plugin( Logger([ @@ -355,6 +356,9 @@ if __name__ == '__main__': parser.add_argument( '--comet_key', help='comet.ml API key' ) + parser.add_argument( + '--primer', help='prime the generator...' + ) parser.set_defaults(**default_params) diff --git a/trainer/plugins.py b/trainer/plugins.py index 0126870..dc3b24a 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -143,18 +143,19 @@ class GeneratorPlugin(Plugin): pattern = 'd-{}-ep{}-s{}.wav' - def __init__(self, samples_path, n_samples, sample_length, sample_rate): + def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer): super().__init__([(1, 'epoch')]) self.samples_path = samples_path self.n_samples = n_samples self.sample_length = sample_length self.sample_rate = sample_rate + self.primer = primer def register(self, trainer): self.generate = Generator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): - samples = self.generate(self.n_samples, self.sample_length) \ + samples = self.generate(self.n_samples, self.sample_length, self.primer) \ .cpu().float().numpy() for i in range(self.n_samples): write_wav( -- cgit v1.2.3-70-g09d2