author     Piotr Kozakowski <kozak000@gmail.com>    2017-05-11 17:49:12 +0200
committer  Piotr Kozakowski <kozak000@gmail.com>    2017-06-29 15:37:26 +0200
commit     2e308fe8e90276a892637be1bfa174e673ebf414 (patch)
tree       4ff187b37d16476cc936aba84184b8feca9c8612
parent     253860fdb0949f0eab6abff09369b0a1236b541a (diff)
Implement SampleRNN
-rw-r--r--  .gitignore                            5
-rw-r--r--  README.md                            33
-rw-r--r--  dataset.py                           68
-rwxr-xr-x  datasets/download-from-youtube.sh    30
-rw-r--r--  model.py                            286
-rw-r--r--  nn.py                                70
-rw-r--r--  optim.py                             21
-rw-r--r--  sample.py                            34
-rw-r--r--  train.py                            316
-rw-r--r--  trainer/__init__.py                 100
-rw-r--r--  trainer/plugins.py                  267
-rw-r--r--  utils.py                             20
12 files changed, 1249 insertions, 1 deletions
diff --git a/.gitignore b/.gitignore
index 72364f9..9c17115 100644
--- a/.gitignore
+++ b/.gitignore
@@ -87,3 +87,8 @@ ENV/
# Rope project settings
.ropeproject
+
+# vim temporary files
+*~
+*.swp
+*.swo
diff --git a/README.md b/README.md
index 9f38e18..440dfe7 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,32 @@
-# samplernn-pytorch
\ No newline at end of file
+# samplernn-pytorch
+
+A PyTorch implementation of [SampleRNN: An Unconditional End-to-End Neural Audio Generation Model](https://arxiv.org/abs/1612.07837).
+
+![A visual representation of the SampleRNN architecture](http://deepsound.io/images/samplernn.png)
+
+It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with an arbitrary number of tiers, whereas the original implementation allows a maximum of 3. However, it doesn't have weight normalization and supports only GRU units (no LSTM). For more details and the motivation behind rewriting this model in PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
+
+## Dependencies
+
+This code requires Python 3.5+ and PyTorch 0.1.12+. Installation instructions for PyTorch are available on their website: http://pytorch.org/. You can install the rest of the dependencies by running `pip install -r requirements.txt`.
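+
+A minimal `requirements.txt`, judging only by the third-party imports in this code (PyTorch itself is installed separately), would presumably list at least:
+
+```
+librosa
+natsort
+matplotlib
+numpy
+```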
+
+## Datasets
+
+We provide a script for creating datasets from YouTube single-video mixes. It downloads a mix, converts it to wav and splits it into equal-length chunks. To run it you need youtube-dl (a recent version; the latest version from pip should be okay) and ffmpeg. To create an example dataset (4 hours of piano music split into 8-second chunks), run:
+
+```
+cd datasets
+./download-from-youtube.sh "https://www.youtube.com/watch?v=EhO_MrRfftU" 8 piano
+```
+
+You can also prepare a dataset yourself. It should be a directory in `datasets/` filled with equal-length wav files. Alternatively, you can create your own dataset format by subclassing `torch.utils.data.Dataset`; take a look at `dataset.FolderDataset` in this repo for an example, or at the sketch below.
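+
+A minimal sketch of such a dataset class (the class name and the in-memory `sequences` argument are hypothetical; it returns quantized 1-D `LongTensor`s in the same format as `FolderDataset`):
+
+```
+import torch
+import utils
+
+class InMemoryDataset(torch.utils.data.Dataset):
+
+    def __init__(self, sequences, overlap_len, q_levels):
+        # `sequences`: a list of equal-length 1-D numpy float arrays in [-1, 1]
+        self.sequences = sequences
+        self.overlap_len = overlap_len
+        self.q_levels = q_levels
+
+    def __getitem__(self, index):
+        seq = torch.from_numpy(self.sequences[index])
+        # prepend `overlap_len` "zero" samples, then quantize into `q_levels` bins
+        return torch.cat([
+            torch.LongTensor(self.overlap_len).fill_(utils.q_zero(self.q_levels)),
+            utils.linear_quantize(seq, self.q_levels)
+        ])
+
+    def __len__(self):
+        return len(self.sequences)
+```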
+
+## Training
+
+To train the model you need to run `train.py`. All model hyperparameters are settable from the command line. Most hyperparameters have sensible default values, so you don't need to provide all of them. Run `python train.py -h` for details. To train on the `piano` dataset using the best hyperparameters we've found, run:
+
+```
+python train.py --exp TEST --frame_sizes 16 4 --n_rnn 2 --dataset piano
+```
+
+The results (training log, loss plots, model checkpoints and generated samples) will be saved in `results/`.
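+
+To generate audio from a trained model outside of training, you can use the `Generator` class (see `sample.py` for a complete example). A minimal sketch, assuming a checkpoint path and hyperparameters matching the trained model:
+
+```
+import torch
+from model import SampleRNN, Predictor, Generator
+from librosa.output import write_wav
+
+model = SampleRNN(
+    frame_sizes=[16, 4], n_rnn=2, dim=1024, learn_h0=True, q_levels=256
+)
+predictor = Predictor(model).cuda()
+# path to a checkpoint saved by SaverPlugin (hypothetical)
+predictor.load_state_dict(torch.load('path/to/checkpoint'))
+
+generator = Generator(predictor.model, cuda=True)
+samples = generator(1, 80000).cpu().float().numpy()
+write_wav('sample.wav', samples[0, :], sr=16000, norm=True)
+```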
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..3bb035f
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,68 @@
+import utils
+
+import torch
+from torch.utils.data import (
+ Dataset, DataLoader as DataLoaderBase
+)
+
+from librosa.core import load
+from natsort import natsorted
+
+from os import listdir
+from os.path import join
+
+
+class FolderDataset(Dataset):
+
+ def __init__(self, path, overlap_len, q_levels, ratio_min=0, ratio_max=1):
+ super().__init__()
+ self.overlap_len = overlap_len
+ self.q_levels = q_levels
+ file_names = natsorted(
+ [join(path, file_name) for file_name in listdir(path)]
+ )
+ self.file_names = file_names[
+ int(ratio_min * len(file_names)) : int(ratio_max * len(file_names))
+ ]
+
+ def __getitem__(self, index):
+ (seq, _) = load(self.file_names[index], sr=None, mono=True)
+ return torch.cat([
+ torch.LongTensor(self.overlap_len) \
+ .fill_(utils.q_zero(self.q_levels)),
+ utils.linear_quantize(
+ torch.from_numpy(seq), self.q_levels
+ )
+ ])
+
+ def __len__(self):
+ return len(self.file_names)
+
+
+class DataLoader(DataLoaderBase):
+
+ def __init__(self, dataset, batch_size, seq_len, overlap_len,
+ *args, **kwargs):
+ super().__init__(dataset, batch_size, *args, **kwargs)
+ self.seq_len = seq_len
+ self.overlap_len = overlap_len
+
+ def __iter__(self):
+ for batch in super().__iter__():
+ (batch_size, n_samples) = batch.size()
+
+ reset = True
+
+ for seq_begin in range(self.overlap_len, n_samples, self.seq_len):
+ from_index = seq_begin - self.overlap_len
+ to_index = seq_begin + self.seq_len
+ sequences = batch[:, from_index : to_index]
+ input_sequences = sequences[:, : -1]
+ target_sequences = sequences[:, self.overlap_len :]
+
+ yield (input_sequences, reset, target_sequences)
+
+ reset = False
+
+ def __len__(self):
+ raise NotImplementedError()
diff --git a/datasets/download-from-youtube.sh b/datasets/download-from-youtube.sh
new file mode 100755
index 0000000..2bcec33
--- /dev/null
+++ b/datasets/download-from-youtube.sh
@@ -0,0 +1,30 @@
+#!/bin/sh
+
+if [ "$#" -ne 3 ]; then
+ echo "Usage: $0 <youtube url> <chunk size in seconds> <dataset path>"
+    exit 1
+fi
+
+url=$1
+chunk_size=$2
+dataset_path=$3
+
+downloaded=".temp"
+rm -f $downloaded
+format=$(youtube-dl -F $url | grep audio | sed -r 's|([0-9]+).*|\1|g' | tail -n 1)
+youtube-dl $url -f $format -o $downloaded
+
+converted=".temp2.wav"
+rm -f $converted
+ffmpeg -i $downloaded -ac 1 -ab 16k -ar 16000 $converted
+rm -f $downloaded
+
+mkdir $dataset_path
+length=$(ffprobe -i $converted -show_entries format=duration -v quiet -of csv="p=0")
+end=$(echo "$length / $chunk_size - 1" | bc)
+echo "splitting..."
+for i in $(seq 0 $end); do
+ ffmpeg -hide_banner -loglevel error -ss $(($i * $chunk_size)) -t $chunk_size -i $converted "$dataset_path/$i.wav"
+done
+echo "done"
+rm -f $converted
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d2db1a9
--- /dev/null
+++ b/model.py
@@ -0,0 +1,286 @@
+import nn
+import utils
+
+import torch
+from torch.nn import functional as F
+from torch.nn import init
+
+import numpy as np
+
+
+class SampleRNN(torch.nn.Module):
+
+ def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels):
+ super().__init__()
+
+ self.dim = dim
+ self.q_levels = q_levels
+
+ ns_frame_samples = map(int, np.cumprod(frame_sizes))
+ self.frame_level_rnns = torch.nn.ModuleList([
+ FrameLevelRNN(
+ frame_size, n_frame_samples, n_rnn, dim, learn_h0
+ )
+ for (frame_size, n_frame_samples) in zip(
+ frame_sizes, ns_frame_samples
+ )
+ ])
+
+ self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels)
+
+ @property
+ def lookback(self):
+ return self.frame_level_rnns[-1].n_frame_samples
+
+
+class FrameLevelRNN(torch.nn.Module):
+
+ def __init__(self, frame_size, n_frame_samples, n_rnn, dim,
+ learn_h0):
+ super().__init__()
+
+ self.frame_size = frame_size
+ self.n_frame_samples = n_frame_samples
+ self.dim = dim
+
+ h0 = torch.zeros(n_rnn, dim)
+ if learn_h0:
+ self.h0 = torch.nn.Parameter(h0)
+ else:
+ self.register_buffer('h0', torch.autograd.Variable(h0))
+
+ self.input_expand = torch.nn.Conv1d(
+ in_channels=n_frame_samples,
+ out_channels=dim,
+ kernel_size=1
+ )
+ init.kaiming_uniform(self.input_expand.weight)
+ init.constant(self.input_expand.bias, 0)
+
+ self.rnn = torch.nn.GRU(
+ input_size=dim,
+ hidden_size=dim,
+ num_layers=n_rnn,
+ batch_first=True
+ )
+ for i in range(n_rnn):
+ nn.concat_init(
+ getattr(self.rnn, 'weight_ih_l{}'.format(i)),
+ [nn.lecun_uniform, nn.lecun_uniform, nn.lecun_uniform]
+ )
+ init.constant(getattr(self.rnn, 'bias_ih_l{}'.format(i)), 0)
+
+ nn.concat_init(
+ getattr(self.rnn, 'weight_hh_l{}'.format(i)),
+ [nn.lecun_uniform, nn.lecun_uniform, init.orthogonal]
+ )
+ init.constant(getattr(self.rnn, 'bias_hh_l{}'.format(i)), 0)
+
+ self.upsampling = nn.LearnedUpsampling1d(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=frame_size
+ )
+ init.uniform(
+ self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim)
+ )
+ init.constant(self.upsampling.bias, 0)
+
+ def forward(self, prev_samples, upper_tier_conditioning, hidden):
+ (batch_size, _, _) = prev_samples.size()
+
+ input = self.input_expand(
+ prev_samples.permute(0, 2, 1)
+ ).permute(0, 2, 1)
+ if upper_tier_conditioning is not None:
+ input += upper_tier_conditioning
+
+ reset = hidden is None
+
+ if hidden is None:
+ (n_rnn, _) = self.h0.size()
+ hidden = self.h0.unsqueeze(1) \
+ .expand(n_rnn, batch_size, self.dim) \
+ .contiguous()
+
+ (output, hidden) = self.rnn(input, hidden)
+
+ output = self.upsampling(
+ output.permute(0, 2, 1)
+ ).permute(0, 2, 1)
+ return (output, hidden)
+
+
+class SampleLevelMLP(torch.nn.Module):
+
+ def __init__(self, frame_size, dim, q_levels):
+ super().__init__()
+
+ self.q_levels = q_levels
+
+ self.embedding = torch.nn.Embedding(
+ self.q_levels,
+ self.q_levels
+ )
+
+ self.input = torch.nn.Conv1d(
+ in_channels=q_levels,
+ out_channels=dim,
+ kernel_size=frame_size,
+ bias=False
+ )
+ init.kaiming_uniform(self.input.weight)
+
+ self.hidden = torch.nn.Conv1d(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=1
+ )
+ init.kaiming_uniform(self.hidden.weight)
+ init.constant(self.hidden.bias, 0)
+
+ self.output = torch.nn.Conv1d(
+ in_channels=dim,
+ out_channels=q_levels,
+ kernel_size=1
+ )
+ nn.lecun_uniform(self.output.weight)
+ init.constant(self.output.bias, 0)
+
+ def forward(self, prev_samples, upper_tier_conditioning):
+ (batch_size, _, _) = upper_tier_conditioning.size()
+
+ prev_samples = self.embedding(
+ prev_samples.contiguous().view(-1)
+ ).view(
+ batch_size, -1, self.q_levels
+ )
+
+ prev_samples = prev_samples.permute(0, 2, 1)
+ upper_tier_conditioning = upper_tier_conditioning.permute(0, 2, 1)
+
+ x = F.relu(self.input(prev_samples) + upper_tier_conditioning)
+ x = F.relu(self.hidden(x))
+ x = self.output(x).permute(0, 2, 1).contiguous()
+
+ return F.log_softmax(x.view(-1, self.q_levels)) \
+ .view(batch_size, -1, self.q_levels)
+
+
+class Runner:
+
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+ self.reset_hidden_states()
+
+ def reset_hidden_states(self):
+ self.hidden_states = {rnn: None for rnn in self.model.frame_level_rnns}
+
+ def run_rnn(self, rnn, prev_samples, upper_tier_conditioning):
+ (output, new_hidden) = rnn(
+ prev_samples, upper_tier_conditioning, self.hidden_states[rnn]
+ )
+ self.hidden_states[rnn] = new_hidden.detach()
+ return output
+
+
+class Predictor(Runner, torch.nn.Module):
+
+ def __init__(self, model):
+ super().__init__(model)
+
+ def forward(self, input_sequences, reset):
+ if reset:
+ self.reset_hidden_states()
+
+ (batch_size, _) = input_sequences.size()
+
+ upper_tier_conditioning = None
+ for rnn in reversed(self.model.frame_level_rnns):
+ from_index = self.model.lookback - rnn.n_frame_samples
+ to_index = -rnn.n_frame_samples + 1
+ prev_samples = 2 * utils.linear_dequantize(
+ input_sequences[:, from_index : to_index],
+ self.model.q_levels
+ )
+ prev_samples = prev_samples.contiguous().view(
+ batch_size, -1, rnn.n_frame_samples
+ )
+
+ upper_tier_conditioning = self.run_rnn(
+ rnn, prev_samples, upper_tier_conditioning
+ )
+
+ bottom_frame_size = self.model.frame_level_rnns[0].frame_size
+ mlp_input_sequences = input_sequences \
+ [:, self.model.lookback - bottom_frame_size :]
+
+ return self.model.sample_level_mlp(
+ mlp_input_sequences, upper_tier_conditioning
+ )
+
+
+class Generator(Runner):
+
+ def __init__(self, model, cuda=False):
+ super().__init__(model)
+ self.cuda = cuda
+
+ def __call__(self, n_seqs, seq_len):
+ # generation doesn't work with CUDNN for some reason
+ torch.backends.cudnn.enabled = False
+
+ self.reset_hidden_states()
+
+ bottom_frame_size = self.model.frame_level_rnns[0].n_frame_samples
+ sequences = torch.LongTensor(n_seqs, self.model.lookback + seq_len) \
+ .fill_(utils.q_zero(self.model.q_levels))
+ frame_level_outputs = [None for _ in self.model.frame_level_rnns]
+
+ for i in range(self.model.lookback, self.model.lookback + seq_len):
+ for (tier_index, rnn) in \
+ reversed(list(enumerate(self.model.frame_level_rnns))):
+ if i % rnn.n_frame_samples != 0:
+ continue
+
+ prev_samples = torch.autograd.Variable(
+ 2 * utils.linear_dequantize(
+ sequences[:, i - rnn.n_frame_samples : i],
+ self.model.q_levels
+ ).unsqueeze(1),
+ volatile=True
+ )
+ if self.cuda:
+ prev_samples = prev_samples.cuda()
+
+ if tier_index == len(self.model.frame_level_rnns) - 1:
+ upper_tier_conditioning = None
+ else:
+ frame_index = (i // rnn.n_frame_samples) % \
+ self.model.frame_level_rnns[tier_index + 1].frame_size
+ upper_tier_conditioning = \
+ frame_level_outputs[tier_index + 1][:, frame_index, :] \
+ .unsqueeze(1)
+
+ frame_level_outputs[tier_index] = self.run_rnn(
+ rnn, prev_samples, upper_tier_conditioning
+ )
+
+ prev_samples = torch.autograd.Variable(
+ sequences[:, i - bottom_frame_size : i],
+ volatile=True
+ )
+ if self.cuda:
+ prev_samples = prev_samples.cuda()
+ upper_tier_conditioning = \
+ frame_level_outputs[0][:, i % bottom_frame_size, :] \
+ .unsqueeze(1)
+ sample_dist = self.model.sample_level_mlp(
+ prev_samples, upper_tier_conditioning
+ ).squeeze(1).exp_().data
+ sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
+
+ torch.backends.cudnn.enabled = True
+
+ return sequences[:, self.model.lookback :]
diff --git a/nn.py b/nn.py
new file mode 100644
index 0000000..db47414
--- /dev/null
+++ b/nn.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+
+import math
+
+
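+# Upsamples a sequence by an integer factor: a transposed convolution with stride
+# equal to kernel_size, plus a learned bias that varies with the position within
+# each upsampled frame.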
+class LearnedUpsampling1d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, bias=True):
+ super().__init__()
+
+ self.conv_t = nn.ConvTranspose1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False
+ )
+
+ if bias:
+ self.bias = nn.Parameter(
+ torch.FloatTensor(out_channels, kernel_size)
+ )
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+    def reset_parameters(self):
+        self.conv_t.reset_parameters()
+        if self.bias is not None:
+            nn.init.constant(self.bias, 0)
+
+ def forward(self, input):
+ (batch_size, _, length) = input.size()
+ (kernel_size,) = self.conv_t.kernel_size
+ bias = self.bias.unsqueeze(0).unsqueeze(2).expand(
+ batch_size, self.conv_t.out_channels,
+ length, kernel_size
+ ).contiguous().view(
+ batch_size, self.conv_t.out_channels,
+ length * kernel_size
+ )
+ return self.conv_t(input) + bias
+
+
+def lecun_uniform(tensor):
+ fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
+ nn.init.uniform(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
+
+
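+# Initialize a weight matrix that stores several stacked matrices (e.g. the 3 GRU
+# gate weights), applying one init function from `inits` to each chunk.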
+def concat_init(tensor, inits):
+ try:
+ tensor = tensor.data
+ except AttributeError:
+ pass
+
+ (length, fan_out) = tensor.size()
+ fan_in = length // len(inits)
+
+ chunk = tensor.new(fan_in, fan_out)
+ for (i, init) in enumerate(inits):
+ init(chunk)
+ tensor[i * fan_in : (i + 1) * fan_in, :] = chunk
+
+
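+# Sequence NLL averaged over all positions and batch elements, converted from nats
+# to bits (log2(e) = 1 / ln 2).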
+def sequence_nll_loss_bits(input, target, *args, **kwargs):
+ (_, _, n_classes) = input.size()
+ return nn.functional.nll_loss(
+ input.view(-1, n_classes), target.view(-1), *args, **kwargs
+ ) * math.log(math.e, 2)
diff --git a/optim.py b/optim.py
new file mode 100644
index 0000000..95f9df2
--- /dev/null
+++ b/optim.py
@@ -0,0 +1,21 @@
+from torch.nn.functional import hardtanh
+
+
+def gradient_clipping(optimizer, min=-1, max=1):
+
+ class OptimizerWrapper(object):
+
+ def step(self, closure):
+ def closure_wrapper():
+ loss = closure()
+ for group in optimizer.param_groups:
+ for p in group['params']:
+ hardtanh(p.grad, min, max, inplace=True)
+ return loss
+
+ return optimizer.step(closure_wrapper)
+
+ def __getattr__(self, attr):
+ return getattr(optimizer, attr)
+
+ return OptimizerWrapper()
diff --git a/sample.py b/sample.py
new file mode 100644
index 0000000..775fd24
--- /dev/null
+++ b/sample.py
@@ -0,0 +1,34 @@
+from model import SampleRNN, Predictor, Generator
+
+import torch
+
+from librosa.output import write_wav
+
+from time import time
+
+
+def main():
+ model = SampleRNN(
+ frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256
+ )
+ predictor = Predictor(model).cuda()
+ predictor.load_state_dict(torch.load('model.tar'))
+
+ generator = Generator(predictor.model, cuda=True)
+
+ t = time()
+ samples = generator(5, 16000)
+ print('generated in {}s'.format(time() - t))
+
+ write_wav(
+ 'sample.wav',
+ samples.cpu().float().numpy()[0, :],
+ sr=16000,
+ norm=True
+ )
+
+if __name__ == '__main__':
+ main()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..c47cd4b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,316 @@
+from model import SampleRNN, Predictor
+from optim import gradient_clipping
+from nn import sequence_nll_loss_bits
+from trainer import Trainer
+from trainer.plugins import (
+ TrainingLossMonitor, ValidationPlugin, AbsoluteTimeMonitor, SaverPlugin,
+ GeneratorPlugin, StatsPlugin
+)
+from dataset import FolderDataset, DataLoader
+
+import torch
+from torch.utils.trainer.plugins import Logger
+
+from natsort import natsorted
+
+from functools import reduce
+import os
+import shutil
+import sys
+from glob import glob
+import re
+import argparse
+
+
+default_params = {
+ # model parameters
+ 'n_rnn': 1,
+ 'dim': 1024,
+ 'learn_h0': True,
+ 'q_levels': 256,
+ 'seq_len': 1024,
+ 'batch_size': 128,
+ 'val_frac': 0.1,
+ 'test_frac': 0.1,
+
+ # training parameters
+ 'keep_old_checkpoints': False,
+ 'datasets_path': 'datasets',
+ 'results_path': 'results',
+ 'epoch_limit': 1000,
+ 'resume': True,
+ 'sample_rate': 16000,
+ 'n_samples': 1,
+ 'sample_length': 80000,
+ 'loss_smoothing': 0.99
+}
+
+tag_params = [
+ 'exp', 'frame_sizes', 'n_rnn', 'dim', 'learn_h0', 'q_levels', 'seq_len',
+ 'batch_size', 'dataset', 'val_frac', 'test_frac'
+]
+
+def make_tag(params):
+ def to_string(value):
+ if isinstance(value, bool):
+ return 'T' if value else 'F'
+ elif isinstance(value, list):
+ return ','.join(map(to_string, value))
+ else:
+ return str(value)
+
+ return '-'.join(
+ key + ':' + to_string(params[key])
+ for key in tag_params
+ if key not in default_params or params[key] != default_params[key]
+ )
+
+def setup_results_dir(params):
+ def ensure_dir_exists(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ tag = make_tag(params)
+ results_path = os.path.abspath(params['results_path'])
+ ensure_dir_exists(results_path)
+ results_path = os.path.join(results_path, tag)
+ if not os.path.exists(results_path):
+ os.makedirs(results_path)
+ elif not params['resume']:
+ shutil.rmtree(results_path)
+ os.makedirs(results_path)
+
+ for subdir in ['checkpoints', 'samples']:
+ ensure_dir_exists(os.path.join(results_path, subdir))
+
+ return results_path
+
+def load_last_checkpoint(checkpoints_path):
+ checkpoints_pattern = os.path.join(
+ checkpoints_path, SaverPlugin.last_pattern.format('*', '*')
+ )
+ checkpoint_paths = natsorted(glob(checkpoints_pattern))
+ if len(checkpoint_paths) > 0:
+ checkpoint_path = checkpoint_paths[-1]
+ checkpoint_name = os.path.basename(checkpoint_path)
+ match = re.match(
+ SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'),
+ checkpoint_name
+ )
+ epoch = int(match.group(1))
+ iteration = int(match.group(2))
+ return (torch.load(checkpoint_path), epoch, iteration)
+ else:
+ return None
+
+def tee_stdout(log_path):
+ log_file = open(log_path, 'a', 1)
+ stdout = sys.stdout
+
+ class Tee:
+
+ def write(self, string):
+ log_file.write(string)
+ stdout.write(string)
+
+ def flush(self):
+ log_file.flush()
+ stdout.flush()
+
+ sys.stdout = Tee()
+
+def make_data_loader(overlap_len, params):
+ path = os.path.join(params['datasets_path'], params['dataset'])
+ def data_loader(split_from, split_to, eval):
+ dataset = FolderDataset(
+ path, overlap_len, params['q_levels'], split_from, split_to
+ )
+ return DataLoader(
+ dataset,
+ batch_size=params['batch_size'],
+ seq_len=params['seq_len'],
+ overlap_len=overlap_len,
+ shuffle=(not eval),
+ drop_last=(not eval)
+ )
+ return data_loader
+
+def main(exp, frame_sizes, dataset, **params):
+ params = dict(
+ default_params,
+ exp=exp, frame_sizes=frame_sizes, dataset=dataset,
+ **params
+ )
+
+ results_path = setup_results_dir(params)
+ tee_stdout(os.path.join(results_path, 'log'))
+
+ model = SampleRNN(
+ frame_sizes=params['frame_sizes'],
+ n_rnn=params['n_rnn'],
+ dim=params['dim'],
+ learn_h0=params['learn_h0'],
+ q_levels=params['q_levels']
+ ).cuda()
+ predictor = Predictor(model).cuda()
+
+ optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters()))
+
+ data_loader = make_data_loader(model.lookback, params)
+ test_split = 1 - params['test_frac']
+ val_split = test_split - params['val_frac']
+
+ trainer = Trainer(
+ predictor, sequence_nll_loss_bits, optimizer,
+ data_loader(0, val_split, eval=False),
+ cuda=True
+ )
+
+ checkpoints_path = os.path.join(results_path, 'checkpoints')
+ checkpoint_data = load_last_checkpoint(checkpoints_path)
+ if checkpoint_data is not None:
+ (state_dict, epoch, iteration) = checkpoint_data
+ trainer.epochs = epoch
+ trainer.iterations = iteration
+ predictor.load_state_dict(state_dict)
+
+ trainer.register_plugin(TrainingLossMonitor(
+ smoothing=params['loss_smoothing']
+ ))
+ trainer.register_plugin(ValidationPlugin(
+ data_loader(val_split, test_split, eval=True),
+ data_loader(test_split, 1, eval=True)
+ ))
+ trainer.register_plugin(AbsoluteTimeMonitor())
+ trainer.register_plugin(SaverPlugin(
+ checkpoints_path, params['keep_old_checkpoints']
+ ))
+ trainer.register_plugin(GeneratorPlugin(
+ os.path.join(results_path, 'samples'), params['n_samples'],
+ params['sample_length'], params['sample_rate']
+ ))
+ trainer.register_plugin(
+ Logger([
+ 'training_loss',
+ 'validation_loss',
+ 'test_loss',
+ 'time'
+ ])
+ )
+ trainer.register_plugin(StatsPlugin(
+ results_path,
+ iteration_fields=[
+ 'training_loss',
+ ('training_loss', 'running_avg'),
+ 'time'
+ ],
+ epoch_fields=[
+ 'validation_loss',
+ 'test_loss',
+ 'time'
+ ],
+ plots={
+ 'loss': {
+ 'x': 'iteration',
+ 'ys': [
+ 'training_loss',
+ ('training_loss', 'running_avg'),
+ 'validation_loss',
+ 'test_loss',
+ ],
+ 'log_y': True
+ }
+ }
+ ))
+ trainer.run(params['epoch_limit'])
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ argument_default=argparse.SUPPRESS
+ )
+
+ def parse_bool(arg):
+ arg = arg.lower()
+ if 'true'.startswith(arg):
+ return True
+ elif 'false'.startswith(arg):
+ return False
+ else:
+ raise ValueError()
+
+ parser.add_argument('--exp', required=True, help='experiment name')
+ parser.add_argument(
+ '--frame_sizes', nargs='+', type=int, required=True,
+ help='frame sizes in terms of the number of lower tier frames, \
+ starting from the lowest RNN tier'
+ )
+ parser.add_argument(
+ '--dataset', required=True,
+ help='dataset name - name of a directory in the datasets path \
+ (settable by --datasets_path)'
+ )
+ parser.add_argument(
+ '--n_rnn', type=int, help='number of RNN layers in each tier'
+ )
+ parser.add_argument(
+ '--dim', type=int, help='number of neurons in every RNN and MLP layer'
+ )
+ parser.add_argument(
+ '--learn_h0', type=parse_bool,
+ help='whether to learn the initial states of RNNs'
+ )
+ parser.add_argument(
+ '--q_levels', type=int,
+ help='number of bins in quantization of audio samples'
+ )
+ parser.add_argument(
+ '--seq_len', type=int,
+ help='how many samples to include in each truncated BPTT pass'
+ )
+ parser.add_argument('--batch_size', type=int, help='batch size')
+ parser.add_argument(
+ '--val_frac', type=float,
+ help='fraction of data to go into the validation set'
+ )
+ parser.add_argument(
+ '--test_frac', type=float,
+ help='fraction of data to go into the test set'
+ )
+ parser.add_argument(
+ '--keep_old_checkpoints', type=parse_bool,
+ help='whether to keep checkpoints from past epochs'
+ )
+ parser.add_argument(
+ '--datasets_path', help='path to the directory containing datasets'
+ )
+ parser.add_argument(
+ '--results_path', help='path to the directory to save the results to'
+ )
+    parser.add_argument('--epoch_limit', type=int, help='how many epochs to run')
+ parser.add_argument(
+ '--resume', type=parse_bool, default=True,
+ help='whether to resume training from the last checkpoint'
+ )
+ parser.add_argument(
+ '--sample_rate', type=int,
+ help='sample rate of the training data and generated sound'
+ )
+ parser.add_argument(
+ '--n_samples', type=int,
+ help='number of samples to generate in each epoch'
+ )
+ parser.add_argument(
+ '--sample_length', type=int,
+ help='length of each generated sample (in samples)'
+ )
+ parser.add_argument(
+ '--loss_smoothing', type=float,
+ help='smoothing parameter of the exponential moving average over \
+ training loss, used in the log and in the loss plot'
+ )
+
+ parser.set_defaults(**default_params)
+
+ main(**vars(parser.parse_args()))
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..7e2ea18
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,100 @@
+import torch
+from torch.autograd import Variable
+
+import heapq
+
+
+# Based on torch.utils.trainer.Trainer code.
+# Allows multiple inputs to the model, not all need to be Tensors.
+class Trainer(object):
+
+ def __init__(self, model, criterion, optimizer, dataset, cuda=False):
+ self.model = model
+ self.criterion = criterion
+ self.optimizer = optimizer
+ self.dataset = dataset
+ self.cuda = cuda
+ self.iterations = 0
+ self.epochs = 0
+ self.stats = {}
+ self.plugin_queues = {
+ 'iteration': [],
+ 'epoch': [],
+ 'batch': [],
+ 'update': [],
+ }
+
+ def register_plugin(self, plugin):
+ plugin.register(self)
+
+ intervals = plugin.trigger_interval
+ if not isinstance(intervals, list):
+ intervals = [intervals]
+ for (duration, unit) in intervals:
+ queue = self.plugin_queues[unit]
+ queue.append((duration, len(queue), plugin))
+
+ def call_plugins(self, queue_name, time, *args):
+ args = (time,) + args
+ queue = self.plugin_queues[queue_name]
+ if len(queue) == 0:
+ return
+ while queue[0][0] <= time:
+ plugin = queue[0][2]
+ getattr(plugin, queue_name)(*args)
+ for trigger in plugin.trigger_interval:
+ if trigger[1] == queue_name:
+ interval = trigger[0]
+ new_item = (time + interval, queue[0][1], plugin)
+ heapq.heappushpop(queue, new_item)
+
+ def run(self, epochs=1):
+ for q in self.plugin_queues.values():
+ heapq.heapify(q)
+
+ for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
+ self.train()
+ self.call_plugins('epoch', self.epochs)
+
+ def train(self):
+ for (self.iterations, data) in \
+ enumerate(self.dataset, self.iterations + 1):
+ batch_inputs = data[: -1]
+ batch_target = data[-1]
+ self.call_plugins(
+ 'batch', self.iterations, batch_inputs, batch_target
+ )
+
+ def wrap(input):
+ if torch.is_tensor(input):
+ input = Variable(input)
+ if self.cuda:
+ input = input.cuda()
+ return input
+ batch_inputs = list(map(wrap, batch_inputs))
+
+ batch_target = Variable(batch_target)
+ if self.cuda:
+ batch_target = batch_target.cuda()
+
+ plugin_data = [None, None]
+
+ def closure():
+ batch_output = self.model(*batch_inputs)
+
+ loss = self.criterion(batch_output, batch_target)
+ loss.backward()
+
+ if plugin_data[0] is None:
+ plugin_data[0] = batch_output.data
+ plugin_data[1] = loss.data
+
+ return loss
+
+ self.optimizer.zero_grad()
+ self.optimizer.step(closure)
+ self.call_plugins(
+ 'iteration', self.iterations, batch_inputs, batch_target,
+ *plugin_data
+ )
+ self.call_plugins('update', self.iterations, self.model)
diff --git a/trainer/plugins.py b/trainer/plugins.py
new file mode 100644
index 0000000..552804a
--- /dev/null
+++ b/trainer/plugins.py
@@ -0,0 +1,267 @@
+import matplotlib
+matplotlib.use('Agg')
+
+from model import Generator
+
+import torch
+from torch.autograd import Variable
+from torch.utils.trainer.plugins.plugin import Plugin
+from torch.utils.trainer.plugins.monitor import Monitor
+from torch.utils.trainer.plugins import LossMonitor
+
+from librosa.output import write_wav
+from matplotlib import pyplot
+
+from glob import glob
+import os
+import pickle
+import time
+
+
+class TrainingLossMonitor(LossMonitor):
+
+ stat_name = 'training_loss'
+
+
+class ValidationPlugin(Plugin):
+
+ def __init__(self, val_dataset, test_dataset):
+ super().__init__([(1, 'epoch')])
+ self.val_dataset = val_dataset
+ self.test_dataset = test_dataset
+
+ def register(self, trainer):
+ self.trainer = trainer
+ val_stats = self.trainer.stats.setdefault('validation_loss', {})
+ val_stats['log_epoch_fields'] = ['{last:.4f}']
+ test_stats = self.trainer.stats.setdefault('test_loss', {})
+ test_stats['log_epoch_fields'] = ['{last:.4f}']
+
+ def epoch(self, idx):
+ self.trainer.model.eval()
+
+ val_stats = self.trainer.stats.setdefault('validation_loss', {})
+ val_stats['last'] = self._evaluate(self.val_dataset)
+ test_stats = self.trainer.stats.setdefault('test_loss', {})
+ test_stats['last'] = self._evaluate(self.test_dataset)
+
+ self.trainer.model.train()
+
+ def _evaluate(self, dataset):
+ loss_sum = 0
+ n_examples = 0
+ for data in dataset:
+ batch_inputs = data[: -1]
+ batch_target = data[-1]
+ batch_size = batch_target.size()[0]
+
+ def wrap(input):
+ if torch.is_tensor(input):
+ input = Variable(input, volatile=True)
+ if self.trainer.cuda:
+ input = input.cuda()
+ return input
+ batch_inputs = list(map(wrap, batch_inputs))
+
+ batch_target = Variable(batch_target, volatile=True)
+ if self.trainer.cuda:
+ batch_target = batch_target.cuda()
+
+ batch_output = self.trainer.model(*batch_inputs)
+ loss_sum += self.trainer.criterion(batch_output, batch_target) \
+ .data[0] * batch_size
+
+ n_examples += batch_size
+
+ return loss_sum / n_examples
+
+
+class AbsoluteTimeMonitor(Monitor):
+
+ stat_name = 'time'
+
+ def __init__(self, *args, **kwargs):
+ kwargs.setdefault('unit', 's')
+ kwargs.setdefault('precision', 0)
+ kwargs.setdefault('running_average', False)
+ kwargs.setdefault('epoch_average', False)
+ super(AbsoluteTimeMonitor, self).__init__(*args, **kwargs)
+ self.start_time = None
+
+ def _get_value(self, *args):
+ if self.start_time is None:
+ self.start_time = time.time()
+ return time.time() - self.start_time
+
+
+class SaverPlugin(Plugin):
+
+ last_pattern = 'ep{}-it{}'
+ best_pattern = 'best-ep{}-it{}'
+
+ def __init__(self, checkpoints_path, keep_old_checkpoints):
+ super().__init__([(1, 'epoch')])
+ self.checkpoints_path = checkpoints_path
+ self.keep_old_checkpoints = keep_old_checkpoints
+ self._best_val_loss = float('+inf')
+
+ def register(self, trainer):
+ self.trainer = trainer
+
+ def epoch(self, epoch_index):
+ if not self.keep_old_checkpoints:
+ self._clear(self.last_pattern.format('*', '*'))
+ torch.save(
+ self.trainer.model.state_dict(),
+ os.path.join(
+ self.checkpoints_path,
+ self.last_pattern.format(epoch_index, self.trainer.iterations)
+ )
+ )
+
+ cur_val_loss = self.trainer.stats['validation_loss']['last']
+ if cur_val_loss < self._best_val_loss:
+ self._clear(self.best_pattern.format('*', '*'))
+ torch.save(
+ self.trainer.model.state_dict(),
+ os.path.join(
+ self.checkpoints_path,
+ self.best_pattern.format(
+ epoch_index, self.trainer.iterations
+ )
+ )
+ )
+ self._best_val_loss = cur_val_loss
+
+ def _clear(self, pattern):
+ pattern = os.path.join(self.checkpoints_path, pattern)
+ for file_name in glob(pattern):
+ os.remove(file_name)
+
+
+class GeneratorPlugin(Plugin):
+
+ pattern = 'ep{}-s{}.wav'
+
+ def __init__(self, samples_path, n_samples, sample_length, sample_rate):
+ super().__init__([(1, 'epoch')])
+ self.samples_path = samples_path
+ self.n_samples = n_samples
+ self.sample_length = sample_length
+ self.sample_rate = sample_rate
+
+ def register(self, trainer):
+ self.generate = Generator(trainer.model.model, trainer.cuda)
+
+ def epoch(self, epoch_index):
+ samples = self.generate(self.n_samples, self.sample_length) \
+ .cpu().float().numpy()
+ for i in range(self.n_samples):
+ write_wav(
+ os.path.join(
+ self.samples_path, self.pattern.format(epoch_index, i + 1)
+ ),
+ samples[i, :], sr=self.sample_rate, norm=True
+ )
+
+
+class StatsPlugin(Plugin):
+
+ data_file_name = 'stats.pkl'
+ plot_pattern = '{}.svg'
+
+ def __init__(self, results_path, iteration_fields, epoch_fields, plots):
+ super().__init__([(1, 'iteration'), (1, 'epoch')])
+ self.results_path = results_path
+
+ self.iteration_fields = self._fields_to_pairs(iteration_fields)
+ self.epoch_fields = self._fields_to_pairs(epoch_fields)
+ self.plots = plots
+ self.data = {
+ 'iterations': {
+ field: []
+ for field in self.iteration_fields + [('iteration', 'last')]
+ },
+ 'epochs': {
+ field: []
+ for field in self.epoch_fields + [('iteration', 'last')]
+ }
+ }
+
+ def register(self, trainer):
+ self.trainer = trainer
+
+ def iteration(self, *args):
+ for (field, stat) in self.iteration_fields:
+ self.data['iterations'][field, stat].append(
+ self.trainer.stats[field][stat]
+ )
+
+ self.data['iterations']['iteration', 'last'].append(
+ self.trainer.iterations
+ )
+
+ def epoch(self, epoch_index):
+ for (field, stat) in self.epoch_fields:
+ self.data['epochs'][field, stat].append(
+ self.trainer.stats[field][stat]
+ )
+
+ self.data['epochs']['iteration', 'last'].append(
+ self.trainer.iterations
+ )
+
+ data_file_path = os.path.join(self.results_path, self.data_file_name)
+ with open(data_file_path, 'wb') as f:
+ pickle.dump(self.data, f)
+
+ for (name, info) in self.plots.items():
+ x_field = self._field_to_pair(info['x'])
+
+ try:
+ y_fields = info['ys']
+ except KeyError:
+ y_fields = [info['y']]
+
+ labels = list(map(
+ lambda x: ' '.join(x) if type(x) is tuple else x,
+ y_fields
+ ))
+ y_fields = self._fields_to_pairs(y_fields)
+
+ try:
+ formats = info['formats']
+ except KeyError:
+ formats = [''] * len(y_fields)
+
+ pyplot.gcf().clear()
+
+ for (y_field, format, label) in zip(y_fields, formats, labels):
+ if y_field in self.iteration_fields:
+ part_name = 'iterations'
+ else:
+ part_name = 'epochs'
+
+ xs = self.data[part_name][x_field]
+ ys = self.data[part_name][y_field]
+
+ pyplot.plot(xs, ys, format, label=label)
+
+ if 'log_y' in info and info['log_y']:
+ pyplot.yscale('log')
+
+ pyplot.legend()
+ pyplot.savefig(
+ os.path.join(self.results_path, self.plot_pattern.format(name))
+ )
+
+ @staticmethod
+ def _field_to_pair(field):
+ if type(field) is tuple:
+ return field
+ else:
+ return (field, 'last')
+
+ @classmethod
+ def _fields_to_pairs(cls, fields):
+ return list(map(cls._field_to_pair, fields))
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..320fe95
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,20 @@
+import torch
+from torch import nn
+import numpy as np
+
+
+EPSILON = 1e-5
+
+def linear_quantize(samples, q_levels):
+ samples = samples.clone()
+ samples -= samples.min(dim=-1)[0].expand_as(samples)
+ samples /= samples.max(dim=-1)[0].expand_as(samples)
+ samples *= q_levels - EPSILON
+ samples += EPSILON / 2
+ return samples.long()
+
+def linear_dequantize(samples, q_levels):
+ return samples.float() / (q_levels / 2) - 1
+
+def q_zero(q_levels):
+ return q_levels // 2