From 2e308fe8e90276a892637be1bfa174e673ebf414 Mon Sep 17 00:00:00 2001 From: Piotr Kozakowski Date: Thu, 11 May 2017 17:49:12 +0200 Subject: Implement SampleRNN --- sample.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 sample.py (limited to 'sample.py') diff --git a/sample.py b/sample.py new file mode 100644 index 0000000..775fd24 --- /dev/null +++ b/sample.py @@ -0,0 +1,34 @@ +from model import SampleRNN, Predictor, Generator +from trainer import Trainer, sequence_nll_loss +from dataset import FolderDataset, DataLoader + +import torch +from torch.utils.trainer import plugins + +from librosa.output import write_wav + +from time import time + + +def main(): + model = SampleRNN( + frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256 + ) + predictor = Predictor(model).cuda() + predictor.load_state_dict(torch.load('model.tar')) + + generator = Generator(predictor.model, cuda=True) + + t = time() + samples = generator(5, 16000) + print('generated in {}s'.format(time() - t)) + + write_wav( + 'sample.wav', + samples.cpu().float().numpy()[0, :], + sr=16000, + norm=True + ) + +if __name__ == '__main__': + main() -- cgit v1.2.3-70-g09d2