-rw-r--r--   .gitignore                            5
-rw-r--r--   README.md                            33
-rw-r--r--   dataset.py                           68
-rwxr-xr-x   datasets/download-from-youtube.sh    30
-rw-r--r--   model.py                            286
-rw-r--r--   nn.py                                70
-rw-r--r--   optim.py                             21
-rw-r--r--   sample.py                            34
-rw-r--r--   train.py                            316
-rw-r--r--   trainer/__init__.py                 100
-rw-r--r--   trainer/plugins.py                  267
-rw-r--r--   utils.py                             20

12 files changed, 1249 insertions(+), 1 deletion(-)
diff --git a/.gitignore b/.gitignore
@@ -87,3 +87,8 @@ ENV/
 # Rope project settings
 .ropeproject
+
+# vim temporary files
+*~
+*.swp
+*.swo
diff --git a/README.md b/README.md
@@ -1 +1,32 @@
-# samplernn-pytorch
\ No newline at end of file
+# samplernn-pytorch
+
+A PyTorch implementation of [SampleRNN: An Unconditional End-to-End Neural Audio Generation Model](https://arxiv.org/abs/1612.07837).
+
+
+
+It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with an arbitrary number of tiers, whereas the original implementation allows a maximum of 3 tiers. However, it doesn't have weight normalization and doesn't allow using LSTM units (only GRU). For more details and the motivation behind rewriting this model in PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
+
+## Dependencies
+
+This code requires Python 3.5+ and PyTorch 0.1.12+. Installation instructions for PyTorch are available on their website: http://pytorch.org/. You can install the rest of the dependencies by running `pip install -r requirements.txt`.
+
+## Datasets
+
+We provide a script for creating datasets from YouTube single-video mixes. It downloads a mix, converts it to wav and splits it into equal-length chunks. To run it you need youtube-dl (a recent version; the latest version from pip should be okay) and ffmpeg. To create an example dataset - 4 hours of piano music split into 8-second chunks - run:
+
+```
+cd datasets
+./download-from-youtube.sh "https://www.youtube.com/watch?v=EhO_MrRfftU" 8 piano
+```
+
+You can also prepare a dataset yourself. It should be a directory in `datasets/` filled with equal-length wav files. Or you can create your own dataset format by subclassing `torch.utils.data.Dataset`. It's easy; take a look at `dataset.FolderDataset` in this repo for an example.
+
+## Training
+
+To train the model you need to run `train.py`. All model hyperparameters can be set from the command line. Most hyperparameters have sensible default values, so you don't need to provide all of them. Run `python train.py -h` for details. To train on the `piano` dataset using the best hyperparameters we've found, run:
+
+```
+python train.py --exp TEST --frame_sizes 16 4 --n_rnn 2 --dataset piano
+```
+
+The results - training log, loss plots, model checkpoints and generated samples - will be saved in `results/`.
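Editorial note (not part of this commit): the README above suggests plugging in a custom dataset by subclassing `torch.utils.data.Dataset`. A replacement only needs `__getitem__` and `__len__`, and should return the same kind of sequence as `dataset.FolderDataset` in the diff below: `overlap_len` padding samples at `q_zero`, followed by the linearly quantized waveform. The class name and the idea of serving pre-loaded arrays are assumptions made purely for illustration.

```python
# Editorial sketch: a hypothetical drop-in alternative to dataset.FolderDataset
# that serves waveforms already loaded into memory (not part of this diff).
import torch
from torch.utils.data import Dataset

import utils  # this repo's quantization helpers


class InMemoryDataset(Dataset):

    def __init__(self, waveforms, overlap_len, q_levels):
        # `waveforms`: list of 1-D float numpy arrays, one per training chunk
        super().__init__()
        self.waveforms = waveforms
        self.overlap_len = overlap_len
        self.q_levels = q_levels

    def __getitem__(self, index):
        seq = torch.from_numpy(self.waveforms[index])
        # Same contract as FolderDataset: overlap_len "zero" samples used as
        # warm-up context, then the linearly quantized signal.
        return torch.cat([
            torch.LongTensor(self.overlap_len)
                 .fill_(utils.q_zero(self.q_levels)),
            utils.linear_quantize(seq, self.q_levels)
        ])

    def __len__(self):
        return len(self.waveforms)
```

An instance of such a class can then be handed to `dataset.DataLoader` exactly like `FolderDataset` is in `train.py`.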
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..3bb035f
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,68 @@
+import utils
+
+import torch
+from torch.utils.data import (
+    Dataset, DataLoader as DataLoaderBase
+)
+
+from librosa.core import load
+from natsort import natsorted
+
+from os import listdir
+from os.path import join
+
+
+class FolderDataset(Dataset):
+
+    def __init__(self, path, overlap_len, q_levels, ratio_min=0, ratio_max=1):
+        super().__init__()
+        self.overlap_len = overlap_len
+        self.q_levels = q_levels
+        file_names = natsorted(
+            [join(path, file_name) for file_name in listdir(path)]
+        )
+        self.file_names = file_names[
+            int(ratio_min * len(file_names)) : int(ratio_max * len(file_names))
+        ]
+
+    def __getitem__(self, index):
+        (seq, _) = load(self.file_names[index], sr=None, mono=True)
+        return torch.cat([
+            torch.LongTensor(self.overlap_len) \
+                 .fill_(utils.q_zero(self.q_levels)),
+            utils.linear_quantize(
+                torch.from_numpy(seq), self.q_levels
+            )
+        ])
+
+    def __len__(self):
+        return len(self.file_names)
+
+
+class DataLoader(DataLoaderBase):
+
+    def __init__(self, dataset, batch_size, seq_len, overlap_len,
+                 *args, **kwargs):
+        super().__init__(dataset, batch_size, *args, **kwargs)
+        self.seq_len = seq_len
+        self.overlap_len = overlap_len
+
+    def __iter__(self):
+        for batch in super().__iter__():
+            (batch_size, n_samples) = batch.size()
+
+            reset = True
+
+            for seq_begin in range(self.overlap_len, n_samples, self.seq_len):
+                from_index = seq_begin - self.overlap_len
+                to_index = seq_begin + self.seq_len
+                sequences = batch[:, from_index : to_index]
+                input_sequences = sequences[:, : -1]
+                target_sequences = sequences[:, self.overlap_len :]
+
+                yield (input_sequences, reset, target_sequences)
+
+                reset = False
+
+    def __len__(self):
+        raise NotImplementedError()
diff --git a/datasets/download-from-youtube.sh b/datasets/download-from-youtube.sh
new file mode 100755
index 0000000..2bcec33
--- /dev/null
+++ b/datasets/download-from-youtube.sh
@@ -0,0 +1,30 @@
+#!/bin/sh
+
+if [ "$#" -ne 3 ]; then
+    echo "Usage: $0 <youtube url> <chunk size in seconds> <dataset path>"
+    exit
+fi
+
+url=$1
+chunk_size=$2
+dataset_path=$3
+
+downloaded=".temp"
+rm -f $downloaded
+format=$(youtube-dl -F $url | grep audio | sed -r 's|([0-9]+).*|\1|g' | tail -n 1)
+youtube-dl $url -f $format -o $downloaded
+
+converted=".temp2.wav"
+rm -f $converted
+ffmpeg -i $downloaded -ac 1 -ab 16k -ar 16000 $converted
+rm -f $downloaded
+
+mkdir $dataset_path
+length=$(ffprobe -i $converted -show_entries format=duration -v quiet -of csv="p=0")
+end=$(echo "$length / $chunk_size - 1" | bc)
+echo "splitting..."
+for i in $(seq 0 $end); do + ffmpeg -hide_banner -loglevel error -ss $(($i * $chunk_size)) -t $chunk_size -i $converted "$dataset_path/$i.wav" +done +echo "done" +rm -f $converted diff --git a/model.py b/model.py new file mode 100644 index 0000000..d2db1a9 --- /dev/null +++ b/model.py @@ -0,0 +1,286 @@ +import nn +import utils + +import torch +from torch.nn import functional as F +from torch.nn import init + +import numpy as np + + +class SampleRNN(torch.nn.Module): + + def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels): + super().__init__() + + self.dim = dim + self.q_levels = q_levels + + ns_frame_samples = map(int, np.cumprod(frame_sizes)) + self.frame_level_rnns = torch.nn.ModuleList([ + FrameLevelRNN( + frame_size, n_frame_samples, n_rnn, dim, learn_h0 + ) + for (frame_size, n_frame_samples) in zip( + frame_sizes, ns_frame_samples + ) + ]) + + self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels) + + @property + def lookback(self): + return self.frame_level_rnns[-1].n_frame_samples + + +class FrameLevelRNN(torch.nn.Module): + + def __init__(self, frame_size, n_frame_samples, n_rnn, dim, + learn_h0): + super().__init__() + + self.frame_size = frame_size + self.n_frame_samples = n_frame_samples + self.dim = dim + + h0 = torch.zeros(n_rnn, dim) + if learn_h0: + self.h0 = torch.nn.Parameter(h0) + else: + self.register_buffer('h0', torch.autograd.Variable(h0)) + + self.input_expand = torch.nn.Conv1d( + in_channels=n_frame_samples, + out_channels=dim, + kernel_size=1 + ) + init.kaiming_uniform(self.input_expand.weight) + init.constant(self.input_expand.bias, 0) + + self.rnn = torch.nn.GRU( + input_size=dim, + hidden_size=dim, + num_layers=n_rnn, + batch_first=True + ) + for i in range(n_rnn): + nn.concat_init( + getattr(self.rnn, 'weight_ih_l{}'.format(i)), + [nn.lecun_uniform, nn.lecun_uniform, nn.lecun_uniform] + ) + init.constant(getattr(self.rnn, 'bias_ih_l{}'.format(i)), 0) + + nn.concat_init( + getattr(self.rnn, 'weight_hh_l{}'.format(i)), + [nn.lecun_uniform, nn.lecun_uniform, init.orthogonal] + ) + init.constant(getattr(self.rnn, 'bias_hh_l{}'.format(i)), 0) + + self.upsampling = nn.LearnedUpsampling1d( + in_channels=dim, + out_channels=dim, + kernel_size=frame_size + ) + init.uniform( + self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim) + ) + init.constant(self.upsampling.bias, 0) + + def forward(self, prev_samples, upper_tier_conditioning, hidden): + (batch_size, _, _) = prev_samples.size() + + input = self.input_expand( + prev_samples.permute(0, 2, 1) + ).permute(0, 2, 1) + if upper_tier_conditioning is not None: + input += upper_tier_conditioning + + reset = hidden is None + + if hidden is None: + (n_rnn, _) = self.h0.size() + hidden = self.h0.unsqueeze(1) \ + .expand(n_rnn, batch_size, self.dim) \ + .contiguous() + + (output, hidden) = self.rnn(input, hidden) + + output = self.upsampling( + output.permute(0, 2, 1) + ).permute(0, 2, 1) + return (output, hidden) + + +class SampleLevelMLP(torch.nn.Module): + + def __init__(self, frame_size, dim, q_levels): + super().__init__() + + self.q_levels = q_levels + + self.embedding = torch.nn.Embedding( + self.q_levels, + self.q_levels + ) + + self.input = torch.nn.Conv1d( + in_channels=q_levels, + out_channels=dim, + kernel_size=frame_size, + bias=False + ) + init.kaiming_uniform(self.input.weight) + + self.hidden = torch.nn.Conv1d( + in_channels=dim, + out_channels=dim, + kernel_size=1 + ) + init.kaiming_uniform(self.hidden.weight) + init.constant(self.hidden.bias, 0) + + self.output = 
torch.nn.Conv1d( + in_channels=dim, + out_channels=q_levels, + kernel_size=1 + ) + nn.lecun_uniform(self.output.weight) + init.constant(self.output.bias, 0) + + def forward(self, prev_samples, upper_tier_conditioning): + (batch_size, _, _) = upper_tier_conditioning.size() + + prev_samples = self.embedding( + prev_samples.contiguous().view(-1) + ).view( + batch_size, -1, self.q_levels + ) + + prev_samples = prev_samples.permute(0, 2, 1) + upper_tier_conditioning = upper_tier_conditioning.permute(0, 2, 1) + + x = F.relu(self.input(prev_samples) + upper_tier_conditioning) + x = F.relu(self.hidden(x)) + x = self.output(x).permute(0, 2, 1).contiguous() + + return F.log_softmax(x.view(-1, self.q_levels)) \ + .view(batch_size, -1, self.q_levels) + + +class Runner: + + def __init__(self, model): + super().__init__() + self.model = model + self.reset_hidden_states() + + def reset_hidden_states(self): + self.hidden_states = {rnn: None for rnn in self.model.frame_level_rnns} + + def run_rnn(self, rnn, prev_samples, upper_tier_conditioning): + (output, new_hidden) = rnn( + prev_samples, upper_tier_conditioning, self.hidden_states[rnn] + ) + self.hidden_states[rnn] = new_hidden.detach() + return output + + +class Predictor(Runner, torch.nn.Module): + + def __init__(self, model): + super().__init__(model) + + def forward(self, input_sequences, reset): + if reset: + self.reset_hidden_states() + + (batch_size, _) = input_sequences.size() + + upper_tier_conditioning = None + for rnn in reversed(self.model.frame_level_rnns): + from_index = self.model.lookback - rnn.n_frame_samples + to_index = -rnn.n_frame_samples + 1 + prev_samples = 2 * utils.linear_dequantize( + input_sequences[:, from_index : to_index], + self.model.q_levels + ) + prev_samples = prev_samples.contiguous().view( + batch_size, -1, rnn.n_frame_samples + ) + + upper_tier_conditioning = self.run_rnn( + rnn, prev_samples, upper_tier_conditioning + ) + + bottom_frame_size = self.model.frame_level_rnns[0].frame_size + mlp_input_sequences = input_sequences \ + [:, self.model.lookback - bottom_frame_size :] + + return self.model.sample_level_mlp( + mlp_input_sequences, upper_tier_conditioning + ) + + +class Generator(Runner): + + def __init__(self, model, cuda=False): + super().__init__(model) + self.cuda = cuda + + def __call__(self, n_seqs, seq_len): + # generation doesn't work with CUDNN for some reason + torch.backends.cudnn.enabled = False + + self.reset_hidden_states() + + bottom_frame_size = self.model.frame_level_rnns[0].n_frame_samples + sequences = torch.LongTensor(n_seqs, self.model.lookback + seq_len) \ + .fill_(utils.q_zero(self.model.q_levels)) + frame_level_outputs = [None for _ in self.model.frame_level_rnns] + + for i in range(self.model.lookback, self.model.lookback + seq_len): + for (tier_index, rnn) in \ + reversed(list(enumerate(self.model.frame_level_rnns))): + if i % rnn.n_frame_samples != 0: + continue + + prev_samples = torch.autograd.Variable( + 2 * utils.linear_dequantize( + sequences[:, i - rnn.n_frame_samples : i], + self.model.q_levels + ).unsqueeze(1), + volatile=True + ) + if self.cuda: + prev_samples = prev_samples.cuda() + + if tier_index == len(self.model.frame_level_rnns) - 1: + upper_tier_conditioning = None + else: + frame_index = (i // rnn.n_frame_samples) % \ + self.model.frame_level_rnns[tier_index + 1].frame_size + upper_tier_conditioning = \ + frame_level_outputs[tier_index + 1][:, frame_index, :] \ + .unsqueeze(1) + + frame_level_outputs[tier_index] = self.run_rnn( + rnn, prev_samples, 
upper_tier_conditioning + ) + + prev_samples = torch.autograd.Variable( + sequences[:, i - bottom_frame_size : i], + volatile=True + ) + if self.cuda: + prev_samples = prev_samples.cuda() + upper_tier_conditioning = \ + frame_level_outputs[0][:, i % bottom_frame_size, :] \ + .unsqueeze(1) + sample_dist = self.model.sample_level_mlp( + prev_samples, upper_tier_conditioning + ).squeeze(1).exp_().data + sequences[:, i] = sample_dist.multinomial(1).squeeze(1) + + torch.backends.cudnn.enabled = True + + return sequences[:, self.model.lookback :] @@ -0,0 +1,70 @@ +import torch +from torch import nn + +import math + + +class LearnedUpsampling1d(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + super().__init__() + + self.conv_t = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=kernel_size, + bias=False + ) + + if bias: + self.bias = nn.Parameter( + torch.FloatTensor(out_channels, kernel_size) + ) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + self.conv_t.reset_parameters() + nn.init.constant(self.bias, 0) + + def forward(self, input): + (batch_size, _, length) = input.size() + (kernel_size,) = self.conv_t.kernel_size + bias = self.bias.unsqueeze(0).unsqueeze(2).expand( + batch_size, self.conv_t.out_channels, + length, kernel_size + ).contiguous().view( + batch_size, self.conv_t.out_channels, + length * kernel_size + ) + return self.conv_t(input) + bias + + +def lecun_uniform(tensor): + fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in') + nn.init.uniform(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in)) + + +def concat_init(tensor, inits): + try: + tensor = tensor.data + except AttributeError: + pass + + (length, fan_out) = tensor.size() + fan_in = length // len(inits) + + chunk = tensor.new(fan_in, fan_out) + for (i, init) in enumerate(inits): + init(chunk) + tensor[i * fan_in : (i + 1) * fan_in, :] = chunk + + +def sequence_nll_loss_bits(input, target, *args, **kwargs): + (_, _, n_classes) = input.size() + return nn.functional.nll_loss( + input.view(-1, n_classes), target.view(-1), *args, **kwargs + ) * math.log(math.e, 2) diff --git a/optim.py b/optim.py new file mode 100644 index 0000000..95f9df2 --- /dev/null +++ b/optim.py @@ -0,0 +1,21 @@ +from torch.nn.functional import hardtanh + + +def gradient_clipping(optimizer, min=-1, max=1): + + class OptimizerWrapper(object): + + def step(self, closure): + def closure_wrapper(): + loss = closure() + for group in optimizer.param_groups: + for p in group['params']: + hardtanh(p.grad, min, max, inplace=True) + return loss + + return optimizer.step(closure_wrapper) + + def __getattr__(self, attr): + return getattr(optimizer, attr) + + return OptimizerWrapper() diff --git a/sample.py b/sample.py new file mode 100644 index 0000000..775fd24 --- /dev/null +++ b/sample.py @@ -0,0 +1,34 @@ +from model import SampleRNN, Predictor, Generator +from trainer import Trainer, sequence_nll_loss +from dataset import FolderDataset, DataLoader + +import torch +from torch.utils.trainer import plugins + +from librosa.output import write_wav + +from time import time + + +def main(): + model = SampleRNN( + frame_sizes=[16, 4], n_rnn=1, dim=1024, learn_h0=True, q_levels=256 + ) + predictor = Predictor(model).cuda() + predictor.load_state_dict(torch.load('model.tar')) + + generator = Generator(predictor.model, cuda=True) + + t = time() + samples = generator(5, 16000) + print('generated in 
{}s'.format(time() - t)) + + write_wav( + 'sample.wav', + samples.cpu().float().numpy()[0, :], + sr=16000, + norm=True + ) + +if __name__ == '__main__': + main() diff --git a/train.py b/train.py new file mode 100644 index 0000000..c47cd4b --- /dev/null +++ b/train.py @@ -0,0 +1,316 @@ +from model import SampleRNN, Predictor +from optim import gradient_clipping +from nn import sequence_nll_loss_bits +from trainer import Trainer +from trainer.plugins import ( + TrainingLossMonitor, ValidationPlugin, AbsoluteTimeMonitor, SaverPlugin, + GeneratorPlugin, StatsPlugin +) +from dataset import FolderDataset, DataLoader + +import torch +from torch.utils.trainer.plugins import Logger + +from natsort import natsorted + +from functools import reduce +import os +import shutil +import sys +from glob import glob +import re +import argparse + + +default_params = { + # model parameters + 'n_rnn': 1, + 'dim': 1024, + 'learn_h0': True, + 'q_levels': 256, + 'seq_len': 1024, + 'batch_size': 128, + 'val_frac': 0.1, + 'test_frac': 0.1, + + # training parameters + 'keep_old_checkpoints': False, + 'datasets_path': 'datasets', + 'results_path': 'results', + 'epoch_limit': 1000, + 'resume': True, + 'sample_rate': 16000, + 'n_samples': 1, + 'sample_length': 80000, + 'loss_smoothing': 0.99 +} + +tag_params = [ + 'exp', 'frame_sizes', 'n_rnn', 'dim', 'learn_h0', 'q_levels', 'seq_len', + 'batch_size', 'dataset', 'val_frac', 'test_frac' +] + +def make_tag(params): + def to_string(value): + if isinstance(value, bool): + return 'T' if value else 'F' + elif isinstance(value, list): + return ','.join(map(to_string, value)) + else: + return str(value) + + return '-'.join( + key + ':' + to_string(params[key]) + for key in tag_params + if key not in default_params or params[key] != default_params[key] + ) + +def setup_results_dir(params): + def ensure_dir_exists(path): + if not os.path.exists(path): + os.makedirs(path) + + tag = make_tag(params) + results_path = os.path.abspath(params['results_path']) + ensure_dir_exists(results_path) + results_path = os.path.join(results_path, tag) + if not os.path.exists(results_path): + os.makedirs(results_path) + elif not params['resume']: + shutil.rmtree(results_path) + os.makedirs(results_path) + + for subdir in ['checkpoints', 'samples']: + ensure_dir_exists(os.path.join(results_path, subdir)) + + return results_path + +def load_last_checkpoint(checkpoints_path): + checkpoints_pattern = os.path.join( + checkpoints_path, SaverPlugin.last_pattern.format('*', '*') + ) + checkpoint_paths = natsorted(glob(checkpoints_pattern)) + if len(checkpoint_paths) > 0: + checkpoint_path = checkpoint_paths[-1] + checkpoint_name = os.path.basename(checkpoint_path) + match = re.match( + SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'), + checkpoint_name + ) + epoch = int(match.group(1)) + iteration = int(match.group(2)) + return (torch.load(checkpoint_path), epoch, iteration) + else: + return None + +def tee_stdout(log_path): + log_file = open(log_path, 'a', 1) + stdout = sys.stdout + + class Tee: + + def write(self, string): + log_file.write(string) + stdout.write(string) + + def flush(self): + log_file.flush() + stdout.flush() + + sys.stdout = Tee() + +def make_data_loader(overlap_len, params): + path = os.path.join(params['datasets_path'], params['dataset']) + def data_loader(split_from, split_to, eval): + dataset = FolderDataset( + path, overlap_len, params['q_levels'], split_from, split_to + ) + return DataLoader( + dataset, + batch_size=params['batch_size'], + seq_len=params['seq_len'], + 
overlap_len=overlap_len, + shuffle=(not eval), + drop_last=(not eval) + ) + return data_loader + +def main(exp, frame_sizes, dataset, **params): + params = dict( + default_params, + exp=exp, frame_sizes=frame_sizes, dataset=dataset, + **params + ) + + results_path = setup_results_dir(params) + tee_stdout(os.path.join(results_path, 'log')) + + model = SampleRNN( + frame_sizes=params['frame_sizes'], + n_rnn=params['n_rnn'], + dim=params['dim'], + learn_h0=params['learn_h0'], + q_levels=params['q_levels'] + ).cuda() + predictor = Predictor(model).cuda() + + optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters())) + + data_loader = make_data_loader(model.lookback, params) + test_split = 1 - params['test_frac'] + val_split = test_split - params['val_frac'] + + trainer = Trainer( + predictor, sequence_nll_loss_bits, optimizer, + data_loader(0, val_split, eval=False), + cuda=True + ) + + checkpoints_path = os.path.join(results_path, 'checkpoints') + checkpoint_data = load_last_checkpoint(checkpoints_path) + if checkpoint_data is not None: + (state_dict, epoch, iteration) = checkpoint_data + trainer.epochs = epoch + trainer.iterations = iteration + predictor.load_state_dict(state_dict) + + trainer.register_plugin(TrainingLossMonitor( + smoothing=params['loss_smoothing'] + )) + trainer.register_plugin(ValidationPlugin( + data_loader(val_split, test_split, eval=True), + data_loader(test_split, 1, eval=True) + )) + trainer.register_plugin(AbsoluteTimeMonitor()) + trainer.register_plugin(SaverPlugin( + checkpoints_path, params['keep_old_checkpoints'] + )) + trainer.register_plugin(GeneratorPlugin( + os.path.join(results_path, 'samples'), params['n_samples'], + params['sample_length'], params['sample_rate'] + )) + trainer.register_plugin( + Logger([ + 'training_loss', + 'validation_loss', + 'test_loss', + 'time' + ]) + ) + trainer.register_plugin(StatsPlugin( + results_path, + iteration_fields=[ + 'training_loss', + ('training_loss', 'running_avg'), + 'time' + ], + epoch_fields=[ + 'validation_loss', + 'test_loss', + 'time' + ], + plots={ + 'loss': { + 'x': 'iteration', + 'ys': [ + 'training_loss', + ('training_loss', 'running_avg'), + 'validation_loss', + 'test_loss', + ], + 'log_y': True + } + } + )) + trainer.run(params['epoch_limit']) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + argument_default=argparse.SUPPRESS + ) + + def parse_bool(arg): + arg = arg.lower() + if 'true'.startswith(arg): + return True + elif 'false'.startswith(arg): + return False + else: + raise ValueError() + + parser.add_argument('--exp', required=True, help='experiment name') + parser.add_argument( + '--frame_sizes', nargs='+', type=int, required=True, + help='frame sizes in terms of the number of lower tier frames, \ + starting from the lowest RNN tier' + ) + parser.add_argument( + '--dataset', required=True, + help='dataset name - name of a directory in the datasets path \ + (settable by --datasets_path)' + ) + parser.add_argument( + '--n_rnn', type=int, help='number of RNN layers in each tier' + ) + parser.add_argument( + '--dim', type=int, help='number of neurons in every RNN and MLP layer' + ) + parser.add_argument( + '--learn_h0', type=parse_bool, + help='whether to learn the initial states of RNNs' + ) + parser.add_argument( + '--q_levels', type=int, + help='number of bins in quantization of audio samples' + ) + parser.add_argument( + '--seq_len', type=int, + help='how many samples to include in each truncated BPTT pass' + 
) + parser.add_argument('--batch_size', type=int, help='batch size') + parser.add_argument( + '--val_frac', type=float, + help='fraction of data to go into the validation set' + ) + parser.add_argument( + '--test_frac', type=float, + help='fraction of data to go into the test set' + ) + parser.add_argument( + '--keep_old_checkpoints', type=parse_bool, + help='whether to keep checkpoints from past epochs' + ) + parser.add_argument( + '--datasets_path', help='path to the directory containing datasets' + ) + parser.add_argument( + '--results_path', help='path to the directory to save the results to' + ) + parser.add_argument('--epoch_limit', help='how many epochs to run') + parser.add_argument( + '--resume', type=parse_bool, default=True, + help='whether to resume training from the last checkpoint' + ) + parser.add_argument( + '--sample_rate', type=int, + help='sample rate of the training data and generated sound' + ) + parser.add_argument( + '--n_samples', type=int, + help='number of samples to generate in each epoch' + ) + parser.add_argument( + '--sample_length', type=int, + help='length of each generated sample (in samples)' + ) + parser.add_argument( + '--loss_smoothing', type=float, + help='smoothing parameter of the exponential moving average over \ + training loss, used in the log and in the loss plot' + ) + + parser.set_defaults(**default_params) + + main(**vars(parser.parse_args())) diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000..7e2ea18 --- /dev/null +++ b/trainer/__init__.py @@ -0,0 +1,100 @@ +import torch +from torch.autograd import Variable + +import heapq + + +# Based on torch.utils.trainer.Trainer code. +# Allows multiple inputs to the model, not all need to be Tensors. +class Trainer(object): + + def __init__(self, model, criterion, optimizer, dataset, cuda=False): + self.model = model + self.criterion = criterion + self.optimizer = optimizer + self.dataset = dataset + self.cuda = cuda + self.iterations = 0 + self.epochs = 0 + self.stats = {} + self.plugin_queues = { + 'iteration': [], + 'epoch': [], + 'batch': [], + 'update': [], + } + + def register_plugin(self, plugin): + plugin.register(self) + + intervals = plugin.trigger_interval + if not isinstance(intervals, list): + intervals = [intervals] + for (duration, unit) in intervals: + queue = self.plugin_queues[unit] + queue.append((duration, len(queue), plugin)) + + def call_plugins(self, queue_name, time, *args): + args = (time,) + args + queue = self.plugin_queues[queue_name] + if len(queue) == 0: + return + while queue[0][0] <= time: + plugin = queue[0][2] + getattr(plugin, queue_name)(*args) + for trigger in plugin.trigger_interval: + if trigger[1] == queue_name: + interval = trigger[0] + new_item = (time + interval, queue[0][1], plugin) + heapq.heappushpop(queue, new_item) + + def run(self, epochs=1): + for q in self.plugin_queues.values(): + heapq.heapify(q) + + for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1): + self.train() + self.call_plugins('epoch', self.epochs) + + def train(self): + for (self.iterations, data) in \ + enumerate(self.dataset, self.iterations + 1): + batch_inputs = data[: -1] + batch_target = data[-1] + self.call_plugins( + 'batch', self.iterations, batch_inputs, batch_target + ) + + def wrap(input): + if torch.is_tensor(input): + input = Variable(input) + if self.cuda: + input = input.cuda() + return input + batch_inputs = list(map(wrap, batch_inputs)) + + batch_target = Variable(batch_target) + if self.cuda: + batch_target = 
batch_target.cuda() + + plugin_data = [None, None] + + def closure(): + batch_output = self.model(*batch_inputs) + + loss = self.criterion(batch_output, batch_target) + loss.backward() + + if plugin_data[0] is None: + plugin_data[0] = batch_output.data + plugin_data[1] = loss.data + + return loss + + self.optimizer.zero_grad() + self.optimizer.step(closure) + self.call_plugins( + 'iteration', self.iterations, batch_inputs, batch_target, + *plugin_data + ) + self.call_plugins('update', self.iterations, self.model) diff --git a/trainer/plugins.py b/trainer/plugins.py new file mode 100644 index 0000000..552804a --- /dev/null +++ b/trainer/plugins.py @@ -0,0 +1,267 @@ +import matplotlib +matplotlib.use('Agg') + +from model import Generator + +import torch +from torch.autograd import Variable +from torch.utils.trainer.plugins.plugin import Plugin +from torch.utils.trainer.plugins.monitor import Monitor +from torch.utils.trainer.plugins import LossMonitor + +from librosa.output import write_wav +from matplotlib import pyplot + +from glob import glob +import os +import pickle +import time + + +class TrainingLossMonitor(LossMonitor): + + stat_name = 'training_loss' + + +class ValidationPlugin(Plugin): + + def __init__(self, val_dataset, test_dataset): + super().__init__([(1, 'epoch')]) + self.val_dataset = val_dataset + self.test_dataset = test_dataset + + def register(self, trainer): + self.trainer = trainer + val_stats = self.trainer.stats.setdefault('validation_loss', {}) + val_stats['log_epoch_fields'] = ['{last:.4f}'] + test_stats = self.trainer.stats.setdefault('test_loss', {}) + test_stats['log_epoch_fields'] = ['{last:.4f}'] + + def epoch(self, idx): + self.trainer.model.eval() + + val_stats = self.trainer.stats.setdefault('validation_loss', {}) + val_stats['last'] = self._evaluate(self.val_dataset) + test_stats = self.trainer.stats.setdefault('test_loss', {}) + test_stats['last'] = self._evaluate(self.test_dataset) + + self.trainer.model.train() + + def _evaluate(self, dataset): + loss_sum = 0 + n_examples = 0 + for data in dataset: + batch_inputs = data[: -1] + batch_target = data[-1] + batch_size = batch_target.size()[0] + + def wrap(input): + if torch.is_tensor(input): + input = Variable(input, volatile=True) + if self.trainer.cuda: + input = input.cuda() + return input + batch_inputs = list(map(wrap, batch_inputs)) + + batch_target = Variable(batch_target, volatile=True) + if self.trainer.cuda: + batch_target = batch_target.cuda() + + batch_output = self.trainer.model(*batch_inputs) + loss_sum += self.trainer.criterion(batch_output, batch_target) \ + .data[0] * batch_size + + n_examples += batch_size + + return loss_sum / n_examples + + +class AbsoluteTimeMonitor(Monitor): + + stat_name = 'time' + + def __init__(self, *args, **kwargs): + kwargs.setdefault('unit', 's') + kwargs.setdefault('precision', 0) + kwargs.setdefault('running_average', False) + kwargs.setdefault('epoch_average', False) + super(AbsoluteTimeMonitor, self).__init__(*args, **kwargs) + self.start_time = None + + def _get_value(self, *args): + if self.start_time is None: + self.start_time = time.time() + return time.time() - self.start_time + + +class SaverPlugin(Plugin): + + last_pattern = 'ep{}-it{}' + best_pattern = 'best-ep{}-it{}' + + def __init__(self, checkpoints_path, keep_old_checkpoints): + super().__init__([(1, 'epoch')]) + self.checkpoints_path = checkpoints_path + self.keep_old_checkpoints = keep_old_checkpoints + self._best_val_loss = float('+inf') + + def register(self, trainer): + self.trainer = 
trainer + + def epoch(self, epoch_index): + if not self.keep_old_checkpoints: + self._clear(self.last_pattern.format('*', '*')) + torch.save( + self.trainer.model.state_dict(), + os.path.join( + self.checkpoints_path, + self.last_pattern.format(epoch_index, self.trainer.iterations) + ) + ) + + cur_val_loss = self.trainer.stats['validation_loss']['last'] + if cur_val_loss < self._best_val_loss: + self._clear(self.best_pattern.format('*', '*')) + torch.save( + self.trainer.model.state_dict(), + os.path.join( + self.checkpoints_path, + self.best_pattern.format( + epoch_index, self.trainer.iterations + ) + ) + ) + self._best_val_loss = cur_val_loss + + def _clear(self, pattern): + pattern = os.path.join(self.checkpoints_path, pattern) + for file_name in glob(pattern): + os.remove(file_name) + + +class GeneratorPlugin(Plugin): + + pattern = 'ep{}-s{}.wav' + + def __init__(self, samples_path, n_samples, sample_length, sample_rate): + super().__init__([(1, 'epoch')]) + self.samples_path = samples_path + self.n_samples = n_samples + self.sample_length = sample_length + self.sample_rate = sample_rate + + def register(self, trainer): + self.generate = Generator(trainer.model.model, trainer.cuda) + + def epoch(self, epoch_index): + samples = self.generate(self.n_samples, self.sample_length) \ + .cpu().float().numpy() + for i in range(self.n_samples): + write_wav( + os.path.join( + self.samples_path, self.pattern.format(epoch_index, i + 1) + ), + samples[i, :], sr=self.sample_rate, norm=True + ) + + +class StatsPlugin(Plugin): + + data_file_name = 'stats.pkl' + plot_pattern = '{}.svg' + + def __init__(self, results_path, iteration_fields, epoch_fields, plots): + super().__init__([(1, 'iteration'), (1, 'epoch')]) + self.results_path = results_path + + self.iteration_fields = self._fields_to_pairs(iteration_fields) + self.epoch_fields = self._fields_to_pairs(epoch_fields) + self.plots = plots + self.data = { + 'iterations': { + field: [] + for field in self.iteration_fields + [('iteration', 'last')] + }, + 'epochs': { + field: [] + for field in self.epoch_fields + [('iteration', 'last')] + } + } + + def register(self, trainer): + self.trainer = trainer + + def iteration(self, *args): + for (field, stat) in self.iteration_fields: + self.data['iterations'][field, stat].append( + self.trainer.stats[field][stat] + ) + + self.data['iterations']['iteration', 'last'].append( + self.trainer.iterations + ) + + def epoch(self, epoch_index): + for (field, stat) in self.epoch_fields: + self.data['epochs'][field, stat].append( + self.trainer.stats[field][stat] + ) + + self.data['epochs']['iteration', 'last'].append( + self.trainer.iterations + ) + + data_file_path = os.path.join(self.results_path, self.data_file_name) + with open(data_file_path, 'wb') as f: + pickle.dump(self.data, f) + + for (name, info) in self.plots.items(): + x_field = self._field_to_pair(info['x']) + + try: + y_fields = info['ys'] + except KeyError: + y_fields = [info['y']] + + labels = list(map( + lambda x: ' '.join(x) if type(x) is tuple else x, + y_fields + )) + y_fields = self._fields_to_pairs(y_fields) + + try: + formats = info['formats'] + except KeyError: + formats = [''] * len(y_fields) + + pyplot.gcf().clear() + + for (y_field, format, label) in zip(y_fields, formats, labels): + if y_field in self.iteration_fields: + part_name = 'iterations' + else: + part_name = 'epochs' + + xs = self.data[part_name][x_field] + ys = self.data[part_name][y_field] + + pyplot.plot(xs, ys, format, label=label) + + if 'log_y' in info and info['log_y']: + 
+                pyplot.yscale('log')
+
+            pyplot.legend()
+            pyplot.savefig(
+                os.path.join(self.results_path, self.plot_pattern.format(name))
+            )
+
+    @staticmethod
+    def _field_to_pair(field):
+        if type(field) is tuple:
+            return field
+        else:
+            return (field, 'last')
+
+    @classmethod
+    def _fields_to_pairs(cls, fields):
+        return list(map(cls._field_to_pair, fields))
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..320fe95
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,20 @@
+import torch
+from torch import nn
+import numpy as np
+
+
+EPSILON = 1e-5
+
+def linear_quantize(samples, q_levels):
+    samples = samples.clone()
+    samples -= samples.min(dim=-1)[0].expand_as(samples)
+    samples /= samples.max(dim=-1)[0].expand_as(samples)
+    samples *= q_levels - EPSILON
+    samples += EPSILON / 2
+    return samples.long()
+
+def linear_dequantize(samples, q_levels):
+    return samples.float() / (q_levels / 2) - 1
+
+def q_zero(q_levels):
+    return q_levels // 2
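Editorial note (not part of this commit): the three helpers in `utils.py` above are the glue between raw audio and the model's categorical output. `linear_quantize` maps a waveform onto integer bins, `linear_dequantize` maps bins back to roughly [-1, 1) (the `Predictor` and `Generator` additionally scale that result by 2), and `q_zero` is the middle bin used for padding and for seeding generation. A small round-trip sketch, assuming the 0.x-era PyTorch API this diff targets:

```python
# Editorial sketch: round-tripping a signal through the utils.py helpers.
import numpy as np
import torch

import utils

q_levels = 256

# Any 1-D float signal works; a sine sweep stands in for a real audio chunk.
wave = torch.from_numpy(np.sin(np.linspace(0, 100, 16000)).astype(np.float32))

bins = utils.linear_quantize(wave, q_levels)      # LongTensor of bin indices
assert int(bins.min()) >= 0 and int(bins.max()) < q_levels

approx = utils.linear_dequantize(bins, q_levels)  # back to roughly [-1, 1)
silence = utils.q_zero(q_levels)                  # 128 when q_levels == 256
```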
