summaryrefslogtreecommitdiff
path: root/sample.py
diff options
context:
space:
mode:
authorPiotr Kozakowski <kozak000@gmail.com>2017-05-11 17:49:12 +0200
committerPiotr Kozakowski <kozak000@gmail.com>2017-06-29 15:37:26 +0200
commit2e308fe8e90276a892637be1bfa174e673ebf414 (patch)
tree4ff187b37d16476cc936aba84184b8feca9c8612 /sample.py
parent253860fdb0949f0eab6abff09369b0a1236b541a (diff)
Implement SampleRNN
Diffstat (limited to 'sample.py')
-rw-r--r--sample.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/sample.py b/sample.py
new file mode 100644
index 0000000..775fd24
--- /dev/null
+++ b/sample.py
@@ -0,0 +1,34 @@
+from model import SampleRNN, Predictor, Generator
+from trainer import Trainer, sequence_nll_loss
+from dataset import FolderDataset, DataLoader
+
+import torch
+from torch.utils.trainer import plugins
+
+from librosa.output import write_wav
+
+from time import time
+
+
+def main():
+ model = SampleRNN(
+ frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256
+ )
+ predictor = Predictor(model).cuda()
+ predictor.load_state_dict(torch.load('model.tar'))
+
+ generator = Generator(predictor.model, cuda=True)
+
+ t = time()
+ samples = generator(5, 16000)
+ print('generated in {}s'.format(time() - t))
+
+ write_wav(
+ 'sample.wav',
+ samples.cpu().float().numpy()[0, :],
+ sr=16000,
+ norm=True
+ )
+
+if __name__ == '__main__':
+ main()