diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 19:15:58 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 19:15:58 +0200 |
| commit | 60fb2b7c87b7e6aa179c6a973a8d6e39cbe7c594 (patch) | |
| tree | 8a738a43e8583f38f151cdc643a38b5a9437cda2 /trainer | |
| parent | 9766ef0f3a539be7ee68bb93918f25a3298afe39 (diff) | |
| parent | e2d8a6f26c5e44d970d7c069f171105376835495 (diff) | |
Merge branch 'master' of asdf.us:samplernn
Diffstat (limited to 'trainer')
| -rw-r--r-- | trainer/__init__.py | 9 | ||||
| -rw-r--r-- | trainer/plugins.py | 8 |
2 files changed, 13 insertions, 4 deletions
diff --git a/trainer/__init__.py b/trainer/__init__.py index 7e2ea18..1f39506 100644 --- a/trainer/__init__.py +++ b/trainer/__init__.py @@ -56,6 +56,15 @@ class Trainer(object): self.train() self.call_plugins('epoch', self.epochs) + def generate(self, epochs=1): + for q in self.plugin_queues.values(): + heapq.heapify(q) + + for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1): + # self.train() + self.call_plugins('update', self.iterations, self.model) + self.call_plugins('epoch', self.epochs) + def train(self): for (self.iterations, data) in \ enumerate(self.dataset, self.iterations + 1): diff --git a/trainer/plugins.py b/trainer/plugins.py index f8c299b..0126870 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -141,7 +141,7 @@ class SaverPlugin(Plugin): class GeneratorPlugin(Plugin): - pattern = 'ep{}-s{}.wav' + pattern = 'd-{}-ep{}-s{}.wav' def __init__(self, samples_path, n_samples, sample_length, sample_rate): super().__init__([(1, 'epoch')]) @@ -159,7 +159,7 @@ class GeneratorPlugin(Plugin): for i in range(self.n_samples): write_wav( os.path.join( - self.samples_path, self.pattern.format(epoch_index, i + 1) + self.samples_path, self.pattern.format(int(time.time()), epoch_index, i + 1) ), samples[i, :], sr=self.sample_rate, norm=True ) @@ -168,7 +168,7 @@ class GeneratorPlugin(Plugin): class StatsPlugin(Plugin): data_file_name = 'stats.pkl' - plot_pattern = '{}.svg' + plot_pattern = 'd-{}-{}.svg' def __init__(self, results_path, iteration_fields, epoch_fields, plots): super().__init__([(1, 'iteration'), (1, 'epoch')]) @@ -252,7 +252,7 @@ class StatsPlugin(Plugin): pyplot.legend() pyplot.savefig( - os.path.join(self.results_path, self.plot_pattern.format(name)) + os.path.join(self.results_path, self.plot_pattern.format(int(time.time()), name)) ) @staticmethod |
