diff options
| author | Piotr Kozakowski <kozak000@gmail.com> | 2017-05-11 17:49:12 +0200 |
|---|---|---|
| committer | Piotr Kozakowski <kozak000@gmail.com> | 2017-06-29 15:37:26 +0200 |
| commit | 2e308fe8e90276a892637be1bfa174e673ebf414 (patch) | |
| tree | 4ff187b37d16476cc936aba84184b8feca9c8612 /sample.py | |
| parent | 253860fdb0949f0eab6abff09369b0a1236b541a (diff) | |
Implement SampleRNN
Diffstat (limited to 'sample.py')
| -rw-r--r-- | sample.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/sample.py b/sample.py new file mode 100644 index 0000000..775fd24 --- /dev/null +++ b/sample.py @@ -0,0 +1,34 @@ +from model import SampleRNN, Predictor, Generator +from trainer import Trainer, sequence_nll_loss +from dataset import FolderDataset, DataLoader + +import torch +from torch.utils.trainer import plugins + +from librosa.output import write_wav + +from time import time + + +def main(): + model = SampleRNN( + frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256 + ) + predictor = Predictor(model).cuda() + predictor.load_state_dict(torch.load('model.tar')) + + generator = Generator(predictor.model, cuda=True) + + t = time() + samples = generator(5, 16000) + print('generated in {}s'.format(time() - t)) + + write_wav( + 'sample.wav', + samples.cpu().float().numpy()[0, :], + sr=16000, + norm=True + ) + +if __name__ == '__main__': + main() |
