path: root/model.py
author     Piotr Kozakowski <kozak000@gmail.com>    2017-05-11 17:49:12 +0200
committer  Piotr Kozakowski <kozak000@gmail.com>    2017-06-29 15:37:26 +0200
commit     2e308fe8e90276a892637be1bfa174e673ebf414 (patch)
tree       4ff187b37d16476cc936aba84184b8feca9c8612 /model.py
parent     253860fdb0949f0eab6abff09369b0a1236b541a (diff)
Implement SampleRNN
Diffstat (limited to 'model.py')
-rw-r--r--  model.py  286
1 file changed, 286 insertions, 0 deletions
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d2db1a9
--- /dev/null
+++ b/model.py
@@ -0,0 +1,286 @@
+import nn
+import utils
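+# `nn` and `utils` are this project's own modules (not torch.nn): this file
+# relies on nn.lecun_uniform, nn.concat_init and nn.LearnedUpsampling1d, and
+# on utils.linear_dequantize and utils.q_zero.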
+
+import torch
+from torch.nn import functional as F
+from torch.nn import init
+
+import numpy as np
+
+
+class SampleRNN(torch.nn.Module):
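+ # Hierarchical SampleRNN: a stack of frame-level GRU tiers, each conditioning
+ # the tier below it, followed by a sample-level MLP that outputs a categorical
+ # distribution over q_levels quantization levels for every audio sample.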
+
+ def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels):
+ super().__init__()
+
+ self.dim = dim
+ self.q_levels = q_levels
+
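+ # ns_frame_samples is the running product of frame_sizes: how many raw
+ # samples one frame of each tier spans, growing from the finest tier to
+ # the coarsest.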
+ ns_frame_samples = map(int, np.cumprod(frame_sizes))
+ self.frame_level_rnns = torch.nn.ModuleList([
+ FrameLevelRNN(
+ frame_size, n_frame_samples, n_rnn, dim, learn_h0
+ )
+ for (frame_size, n_frame_samples) in zip(
+ frame_sizes, ns_frame_samples
+ )
+ ])
+
+ self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels)
+
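+ # Number of past samples needed as context before prediction can start:
+ # the frame length of the top (coarsest) tier.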
+ @property
+ def lookback(self):
+ return self.frame_level_rnns[-1].n_frame_samples
+
+
+class FrameLevelRNN(torch.nn.Module):
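+ # One tier of the hierarchy: a 1x1 convolution expands each frame of past
+ # samples into a GRU input, an n_rnn-layer GRU models the frame sequence,
+ # and a learned upsampler stretches its outputs into per-position
+ # conditioning for the tier below.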
+
+ def __init__(self, frame_size, n_frame_samples, n_rnn, dim,
+ learn_h0):
+ super().__init__()
+
+ self.frame_size = frame_size
+ self.n_frame_samples = n_frame_samples
+ self.dim = dim
+
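+ # Initial GRU hidden state, one row per RNN layer; either learned as a
+ # parameter or kept as a fixed zero buffer.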
+ h0 = torch.zeros(n_rnn, dim)
+ if learn_h0:
+ self.h0 = torch.nn.Parameter(h0)
+ else:
+ self.register_buffer('h0', torch.autograd.Variable(h0))
+
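+ # 1x1 convolution mapping each frame of n_frame_samples dequantized samples
+ # to a dim-dimensional GRU input vector.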
+ self.input_expand = torch.nn.Conv1d(
+ in_channels=n_frame_samples,
+ out_channels=dim,
+ kernel_size=1
+ )
+ init.kaiming_uniform(self.input_expand.weight)
+ init.constant(self.input_expand.bias, 0)
+
+ self.rnn = torch.nn.GRU(
+ input_size=dim,
+ hidden_size=dim,
+ num_layers=n_rnn,
+ batch_first=True
+ )
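+ # Initialize the stacked gate weights chunk by chunk: LeCun uniform
+ # throughout, except the recurrent weights of the last chunk (the candidate
+ # activation in PyTorch's reset/update/new GRU layout), which are
+ # orthogonal; all biases start at zero.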
+ for i in range(n_rnn):
+ nn.concat_init(
+ getattr(self.rnn, 'weight_ih_l{}'.format(i)),
+ [nn.lecun_uniform, nn.lecun_uniform, nn.lecun_uniform]
+ )
+ init.constant(getattr(self.rnn, 'bias_ih_l{}'.format(i)), 0)
+
+ nn.concat_init(
+ getattr(self.rnn, 'weight_hh_l{}'.format(i)),
+ [nn.lecun_uniform, nn.lecun_uniform, init.orthogonal]
+ )
+ init.constant(getattr(self.rnn, 'bias_hh_l{}'.format(i)), 0)
+
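+ # Learned transposed convolution that expands each GRU output frame into
+ # frame_size conditioning vectors for the tier below.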
+ self.upsampling = nn.LearnedUpsampling1d(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=frame_size
+ )
+ init.uniform(
+ self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim)
+ )
+ init.constant(self.upsampling.bias, 0)
+
+ def forward(self, prev_samples, upper_tier_conditioning, hidden):
+ (batch_size, _, _) = prev_samples.size()
+
+ input = self.input_expand(
+ prev_samples.permute(0, 2, 1)
+ ).permute(0, 2, 1)
+ if upper_tier_conditioning is not None:
+ input += upper_tier_conditioning
+
+ if hidden is None:
+ (n_rnn, _) = self.h0.size()
+ hidden = self.h0.unsqueeze(1) \
+ .expand(n_rnn, batch_size, self.dim) \
+ .contiguous()
+
+ (output, hidden) = self.rnn(input, hidden)
+
+ output = self.upsampling(
+ output.permute(0, 2, 1)
+ ).permute(0, 2, 1)
+ return (output, hidden)
+
+
+class SampleLevelMLP(torch.nn.Module):
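+ # Bottom tier: embeds the most recent quantized samples, combines a window
+ # of frame_size of them with the conditioning from the lowest frame-level
+ # tier, and emits log-probabilities over q_levels for each output position.
+ # Using convolutions over time lets a whole training subsequence be
+ # computed in parallel.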
+
+ def __init__(self, frame_size, dim, q_levels):
+ super().__init__()
+
+ self.q_levels = q_levels
+
+ self.embedding = torch.nn.Embedding(
+ self.q_levels,
+ self.q_levels
+ )
+
+ self.input = torch.nn.Conv1d(
+ in_channels=q_levels,
+ out_channels=dim,
+ kernel_size=frame_size,
+ bias=False
+ )
+ init.kaiming_uniform(self.input.weight)
+
+ self.hidden = torch.nn.Conv1d(
+ in_channels=dim,
+ out_channels=dim,
+ kernel_size=1
+ )
+ init.kaiming_uniform(self.hidden.weight)
+ init.constant(self.hidden.bias, 0)
+
+ self.output = torch.nn.Conv1d(
+ in_channels=dim,
+ out_channels=q_levels,
+ kernel_size=1
+ )
+ nn.lecun_uniform(self.output.weight)
+ init.constant(self.output.bias, 0)
+
+ def forward(self, prev_samples, upper_tier_conditioning):
+ (batch_size, _, _) = upper_tier_conditioning.size()
+
+ prev_samples = self.embedding(
+ prev_samples.contiguous().view(-1)
+ ).view(
+ batch_size, -1, self.q_levels
+ )
+
+ prev_samples = prev_samples.permute(0, 2, 1)
+ upper_tier_conditioning = upper_tier_conditioning.permute(0, 2, 1)
+
+ x = F.relu(self.input(prev_samples) + upper_tier_conditioning)
+ x = F.relu(self.hidden(x))
+ x = self.output(x).permute(0, 2, 1).contiguous()
+
+ return F.log_softmax(x.view(-1, self.q_levels)) \
+ .view(batch_size, -1, self.q_levels)
+
+
+class Runner:
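+ # Shared machinery for Predictor and Generator: keeps one hidden state per
+ # frame-level tier so consecutive calls continue from where the previous
+ # subsequence ended.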
+
+ def __init__(self, model):
+ super().__init__()
+ self.model = model
+ self.reset_hidden_states()
+
+ def reset_hidden_states(self):
+ self.hidden_states = {rnn: None for rnn in self.model.frame_level_rnns}
+
+ def run_rnn(self, rnn, prev_samples, upper_tier_conditioning):
+ (output, new_hidden) = rnn(
+ prev_samples, upper_tier_conditioning, self.hidden_states[rnn]
+ )
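+ # Detach the hidden state so gradients do not flow back across subsequence
+ # boundaries (truncated backpropagation through time).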
+ self.hidden_states[rnn] = new_hidden.detach()
+ return output
+
+
+class Predictor(Runner, torch.nn.Module):
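+ # Training-time module: runs every tier over a whole input subsequence at
+ # once and returns per-sample log-probabilities.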
+
+ def __init__(self, model):
+ super().__init__(model)
+
+ def forward(self, input_sequences, reset):
+ if reset:
+ self.reset_hidden_states()
+
+ (batch_size, _) = input_sequences.size()
+
+ upper_tier_conditioning = None
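+ # Run the tiers from coarsest to finest; each tier's upsampled output
+ # conditions the next tier down, and the sample-level MLP comes last.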
+ for rnn in reversed(self.model.frame_level_rnns):
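+ # Slice out the past samples this tier conditions on, dequantize them to
+ # real values and regroup them into frames of n_frame_samples.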
+ from_index = self.model.lookback - rnn.n_frame_samples
+ to_index = -rnn.n_frame_samples + 1
+ prev_samples = 2 * utils.linear_dequantize(
+ input_sequences[:, from_index : to_index],
+ self.model.q_levels
+ )
+ prev_samples = prev_samples.contiguous().view(
+ batch_size, -1, rnn.n_frame_samples
+ )
+
+ upper_tier_conditioning = self.run_rnn(
+ rnn, prev_samples, upper_tier_conditioning
+ )
+
+ bottom_frame_size = self.model.frame_level_rnns[0].frame_size
+ mlp_input_sequences = input_sequences \
+ [:, self.model.lookback - bottom_frame_size :]
+
+ return self.model.sample_level_mlp(
+ mlp_input_sequences, upper_tier_conditioning
+ )
+
+
+class Generator(Runner):
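+ # Inference-time wrapper: autoregressively generates n_seqs sequences of
+ # seq_len samples, seeded with `lookback` samples fixed at the zero
+ # quantization level.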
+
+ def __init__(self, model, cuda=False):
+ super().__init__(model)
+ self.cuda = cuda
+
+ def __call__(self, n_seqs, seq_len):
+ # generation doesn't work with CUDNN for some reason
+ torch.backends.cudnn.enabled = False
+
+ self.reset_hidden_states()
+
+ bottom_frame_size = self.model.frame_level_rnns[0].n_frame_samples
+ sequences = torch.LongTensor(n_seqs, self.model.lookback + seq_len) \
+ .fill_(utils.q_zero(self.model.q_levels))
+ frame_level_outputs = [None for _ in self.model.frame_level_rnns]
+
+ for i in range(self.model.lookback, self.model.lookback + seq_len):
+ for (tier_index, rnn) in \
+ reversed(list(enumerate(self.model.frame_level_rnns))):
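+ # A tier only advances when i reaches one of its frame boundaries; in
+ # between, its most recent output keeps conditioning the tiers below.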
+ if i % rnn.n_frame_samples != 0:
+ continue
+
+ prev_samples = torch.autograd.Variable(
+ 2 * utils.linear_dequantize(
+ sequences[:, i - rnn.n_frame_samples : i],
+ self.model.q_levels
+ ).unsqueeze(1),
+ volatile=True
+ )
+ if self.cuda:
+ prev_samples = prev_samples.cuda()
+
+ if tier_index == len(self.model.frame_level_rnns) - 1:
+ upper_tier_conditioning = None
+ else:
+ frame_index = (i // rnn.n_frame_samples) % \
+ self.model.frame_level_rnns[tier_index + 1].frame_size
+ upper_tier_conditioning = \
+ frame_level_outputs[tier_index + 1][:, frame_index, :] \
+ .unsqueeze(1)
+
+ frame_level_outputs[tier_index] = self.run_rnn(
+ rnn, prev_samples, upper_tier_conditioning
+ )
+
+ prev_samples = torch.autograd.Variable(
+ sequences[:, i - bottom_frame_size : i],
+ volatile=True
+ )
+ if self.cuda:
+ prev_samples = prev_samples.cuda()
+ upper_tier_conditioning = \
+ frame_level_outputs[0][:, i % bottom_frame_size, :] \
+ .unsqueeze(1)
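+ # The MLP returns log-probabilities for the next sample; exponentiate and
+ # draw the sample from the resulting categorical distribution.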
+ sample_dist = self.model.sample_level_mlp(
+ prev_samples, upper_tier_conditioning
+ ).squeeze(1).exp_().data
+ sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
+
+ torch.backends.cudnn.enabled = True
+
+ return sequences[:, self.model.lookback :]