From 914ffba685b594c11bbfb8559a44f54042e40d83 Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Mon, 14 May 2018 20:38:13 +0200 Subject: get pl param printing being nice --- trainer/plugins.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'trainer') diff --git a/trainer/plugins.py b/trainer/plugins.py index 562355e..132c33d 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -143,13 +143,16 @@ class GeneratorPlugin(Plugin): pattern = 'd-{}-ep{}-s{}.wav' - def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer): + def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer, primer_a, primer_b, recursive): super().__init__([(1, 'epoch')]) self.samples_path = samples_path self.n_samples = n_samples self.sample_length = sample_length self.sample_rate = sample_rate self.primer = primer + self.primer_a = primer_a + self.primer_b = primer_b + self.recursive = recursive def register(self, trainer): if self.primer == "": @@ -158,7 +161,7 @@ class GeneratorPlugin(Plugin): self.generate = PrimedGenerator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): - samples = self.generate(self.n_samples, self.sample_length, self.primer) \ + samples = self.generate(self.n_samples, self.sample_length, self.primer, self.primer_a, self.primer_b, self.recursive) \ .cpu().float().numpy() for i in range(self.n_samples): write_wav( -- cgit v1.2.3-70-g09d2