summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore2
-rw-r--r--generate.py9
-rw-r--r--train.py6
-rw-r--r--trainer/plugins.py5
4 files changed, 17 insertions, 5 deletions
diff --git a/.gitignore b/.gitignore
index 26a1d76..2c235f0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -113,3 +113,5 @@ results/old_checkpoints/
mgb
run_*
+.DS_Store
+
diff --git a/generate.py b/generate.py
index 92a930f..fb63b09 100644
--- a/generate.py
+++ b/generate.py
@@ -51,7 +51,8 @@ default_params = {
'sample_length': 80000,
'loss_smoothing': 0.99,
'cuda': True,
- 'comet_key': None
+ 'comet_key': None,
+ 'primer': 'zero'
}
tag_params = [
@@ -222,7 +223,8 @@ def main(exp, frame_sizes, dataset, **params):
"""
trainer.register_plugin(GeneratorPlugin(
os.path.join(results_path, 'samples'), params['n_samples'],
- params['sample_length'], params['sample_rate']
+ params['sample_length'], params['sample_rate'],
+ params['primer']
))
"""
trainer.register_plugin(
@@ -359,6 +361,9 @@ if __name__ == '__main__':
parser.add_argument(
'--comet_key', help='comet.ml API key'
)
+ parser.add_argument(
+ '--primer', help='prime the generator...'
+ )
parser.set_defaults(**default_params)
diff --git a/train.py b/train.py
index 1f4ab3b..a40e5f6 100644
--- a/train.py
+++ b/train.py
@@ -218,7 +218,8 @@ def main(exp, frame_sizes, dataset, **params):
))
trainer.register_plugin(GeneratorPlugin(
os.path.join(results_path, 'samples'), params['n_samples'],
- params['sample_length'], params['sample_rate']
+ params['sample_length'], params['sample_rate'],
+ params['primer']
))
trainer.register_plugin(
Logger([
@@ -355,6 +356,9 @@ if __name__ == '__main__':
parser.add_argument(
'--comet_key', help='comet.ml API key'
)
+ parser.add_argument(
+ '--primer', help='prime the generator...'
+ )
parser.set_defaults(**default_params)
diff --git a/trainer/plugins.py b/trainer/plugins.py
index 0126870..dc3b24a 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -143,18 +143,19 @@ class GeneratorPlugin(Plugin):
pattern = 'd-{}-ep{}-s{}.wav'
- def __init__(self, samples_path, n_samples, sample_length, sample_rate):
+ def __init__(self, samples_path, n_samples, sample_length, sample_rate, primer):
super().__init__([(1, 'epoch')])
self.samples_path = samples_path
self.n_samples = n_samples
self.sample_length = sample_length
self.sample_rate = sample_rate
+ self.primer = primer
def register(self, trainer):
self.generate = Generator(trainer.model.model, trainer.cuda)
def epoch(self, epoch_index):
- samples = self.generate(self.n_samples, self.sample_length) \
+ samples = self.generate(self.n_samples, self.sample_length, self.primer) \
.cpu().float().numpy()
for i in range(self.n_samples):
write_wav(