summaryrefslogtreecommitdiff
path: root/sample.py
diff options
context:
space:
mode:
Diffstat (limited to 'sample.py')
-rw-r--r--sample.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/sample.py b/sample.py
new file mode 100644
index 0000000..775fd24
--- /dev/null
+++ b/sample.py
@@ -0,0 +1,34 @@
+from model import SampleRNN, Predictor, Generator
+from trainer import Trainer, sequence_nll_loss
+from dataset import FolderDataset, DataLoader
+
+import torch
+from torch.utils.trainer import plugins
+
+from librosa.output import write_wav
+
+from time import time
+
+
+def main():
+ model = SampleRNN(
+ frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256
+ )
+ predictor = Predictor(model).cuda()
+ predictor.load_state_dict(torch.load('model.tar'))
+
+ generator = Generator(predictor.model, cuda=True)
+
+ t = time()
+ samples = generator(5, 16000)
+ print('generated in {}s'.format(time() - t))
+
+ write_wav(
+ 'sample.wav',
+ samples.cpu().float().numpy()[0, :],
+ sr=16000,
+ norm=True
+ )
+
+if __name__ == '__main__':
+ main()