summaryrefslogtreecommitdiff
path: root/sample.py
blob: 775fd24e128f027759533a6036d0bd13dc95220e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from model import SampleRNN, Predictor, Generator
from trainer import Trainer, sequence_nll_loss
from dataset import FolderDataset, DataLoader

import torch
from torch.utils.trainer import plugins

from librosa.output import write_wav

from time import time


def main(
    checkpoint_path='model.tar',
    output_path='sample.wav',
    n_seqs=5,
    seq_len=16000,
    sample_rate=16000,
):
    """Generate audio with a trained SampleRNN and write one clip as WAV.

    Builds a SampleRNN, loads the weights in ``checkpoint_path`` into a
    ``Predictor`` wrapper on the GPU, runs a ``Generator`` over the
    predictor's model, prints the generation time, and writes row 0 of the
    generated output to ``output_path``.

    Every argument defaults to the value this script previously hard-coded,
    so calling ``main()`` with no arguments behaves as before.

    Args:
        checkpoint_path: file holding a ``Predictor`` state dict, as read by
            ``torch.load``.
        output_path: destination WAV file.
        n_seqs: first argument passed to ``Generator``. Presumably the number
            of sequences generated in one batch (TODO confirm against
            ``model.Generator``). Only row 0 is written out.
        seq_len: second argument passed to ``Generator``. Presumably the
            length of each sequence in samples (TODO confirm).
        sample_rate: sample rate passed to ``write_wav`` as ``sr``.
    """
    # These hyperparameters are hard-coded and must describe the same
    # architecture the checkpoint was saved from. Otherwise load_state_dict
    # (strict by default) rejects the mismatched keys or shapes.
    model = SampleRNN(
        frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256
    )
    predictor = Predictor(model).cuda()
    predictor.load_state_dict(torch.load(checkpoint_path))

    generator = Generator(predictor.model, cuda=True)

    start = time()
    samples = generator(n_seqs, seq_len)
    # CUDA kernels execute asynchronously. Wait for them to finish so the
    # reported time covers the whole generation, not just kernel launches.
    torch.cuda.synchronize()
    print('generated in {}s'.format(time() - start))

    # The output is indexed as a 2-D array and only row 0 is saved; any
    # other rows are discarded. norm=True asks librosa to normalize the
    # amplitude before writing.
    write_wav(
        output_path,
        samples.cpu().float().numpy()[0, :],
        sr=sample_rate,
        norm=True
    )

if __name__ == "__main__":
    # Only generate audio when run as a script; importing this module
    # must not trigger GPU work or file output.
    main()