summaryrefslogtreecommitdiff
path: root/trainer
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 20:38:13 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 20:38:13 +0200
commit914ffba685b594c11bbfb8559a44f54042e40d83 (patch)
tree777948436d295dce892ffc817a2bddb7f7ec176b /trainer
parent45e252f2fc33992b52dc34251c0ba31970b92ee5 (diff)
get pl param printing being nice
Diffstat (limited to 'trainer')
-rw-r--r--trainer/plugins.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py
index 562355e..132c33d 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -143,13 +143,16 @@ class GeneratorPlugin(Plugin):
pattern = 'd-{}-ep{}-s{}.wav'
- def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer):
+ def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer, primer_a, primer_b, recursive):
super().__init__([(1, 'epoch')])
self.samples_path = samples_path
self.n_samples = n_samples
self.sample_length = sample_length
self.sample_rate = sample_rate
self.primer = primer
+ self.primer_a = primer_a
+ self.primer_b = primer_b
+ self.recursive = recursive
def register(self, trainer):
if self.primer == "":
@@ -158,7 +161,7 @@ class GeneratorPlugin(Plugin):
self.generate = PrimedGenerator(trainer.model.model, trainer.cuda)
def epoch(self, epoch_index):
- samples = self.generate(self.n_samples, self.sample_length, self.primer) \
+ samples = self.generate(self.n_samples, self.sample_length, self.primer, self.primer_a, self.primer_b, self.recursive) \
.cpu().float().numpy()
for i in range(self.n_samples):
write_wav(