diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 19:20:31 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 19:20:31 +0200 |
| commit | 88eecd62b75cc032752aa10121d376cc7bca418b (patch) | |
| tree | 9c1a5dfb99c0d88d0523e6c96380257a0d474939 /trainer | |
| parent | 60fb2b7c87b7e6aa179c6a973a8d6e39cbe7c594 (diff) | |
flag to prime the generator
Diffstat (limited to 'trainer')
| -rw-r--r-- | trainer/plugins.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py index 0126870..dc3b24a 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -143,18 +143,19 @@ class GeneratorPlugin(Plugin): pattern = 'd-{}-ep{}-s{}.wav' - def __init__(self, samples_path, n_samples, sample_length, sample_rate): + def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer): super().__init__([(1, 'epoch')]) self.samples_path = samples_path self.n_samples = n_samples self.sample_length = sample_length self.sample_rate = sample_rate + self.primer = primer def register(self, trainer): self.generate = Generator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): - samples = self.generate(self.n_samples, self.sample_length) \ + samples = self.generate(self.n_samples, self.sample_length, self.primer) \ .cpu().float().numpy() for i in range(self.n_samples): write_wav( |
