summaryrefslogtreecommitdiff
path: root/nn.py
diff options
context:
space:
mode:
authorPiotr Kozakowski <kozak000@gmail.com>2017-05-11 17:49:12 +0200
committerPiotr Kozakowski <kozak000@gmail.com>2017-06-29 15:37:26 +0200
commit2e308fe8e90276a892637be1bfa174e673ebf414 (patch)
tree4ff187b37d16476cc936aba84184b8feca9c8612 /nn.py
parent253860fdb0949f0eab6abff09369b0a1236b541a (diff)
Implement SampleRNN
Diffstat (limited to 'nn.py')
-rw-r--r--nn.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/nn.py b/nn.py
new file mode 100644
index 0000000..db47414
--- /dev/null
+++ b/nn.py
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+
+import math
+
+
class LearnedUpsampling1d(nn.Module):
    """Upsample a 1D signal by an integer factor with a learned bias pattern.

    A transposed convolution with stride == kernel_size expands each input
    time step into `kernel_size` output steps (output length is
    `length * kernel_size`). Optionally a learned bias of shape
    (out_channels, kernel_size) is added, giving a distinct bias per
    position within each upsampled window — unlike ConvTranspose1d's
    single per-channel bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()

        # bias=False here: the richer (channel, window-position) bias
        # below replaces the transposed convolution's per-channel bias.
        self.conv_t = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False
        )

        if bias:
            self.bias = nn.Parameter(
                torch.FloatTensor(out_channels, kernel_size)
            )
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Reset conv weights to their default init and zero the bias."""
        self.conv_t.reset_parameters()
        # Guard: self.bias is None when constructed with bias=False;
        # the original unconditional init crashed in that case.
        if self.bias is not None:
            # constant_ is the in-place initializer; the un-suffixed
            # nn.init.constant alias is deprecated/removed.
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        """Upsample `input` of shape (batch, in_channels, length) to
        (batch, out_channels, length * kernel_size)."""
        (batch_size, _, length) = input.size()
        upsampled = self.conv_t(input)
        if self.bias is None:
            # bias=False: just the transposed convolution. The original
            # added None here and raised.
            return upsampled
        (kernel_size,) = self.conv_t.kernel_size
        # Broadcast the (out_channels, kernel_size) bias over the batch and
        # every input step, then flatten the (length, kernel_size) pair
        # into the upsampled time axis so it aligns with conv_t's output.
        bias = self.bias.unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.conv_t.out_channels,
            length, kernel_size
        ).contiguous().view(
            batch_size, self.conv_t.out_channels,
            length * kernel_size
        )
        return upsampled + bias
+
+
def lecun_uniform(tensor):
    """Fill `tensor` in place with LeCun-uniform values.

    Samples U(-sqrt(3 / fan_in), sqrt(3 / fan_in)), which gives unit
    variance scaled by the tensor's fan-in.
    """
    fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
    bound = math.sqrt(3 / fan_in)  # hoisted: computed once, not twice
    # uniform_ is the in-place initializer; the un-suffixed
    # nn.init.uniform alias is deprecated/removed in modern PyTorch.
    nn.init.uniform_(tensor, -bound, bound)
+
+
def concat_init(tensor, inits):
    """Initialize a 2D (rows, fan_out) tensor as stacked row-chunks.

    The rows are split into `len(inits)` equal chunks and each init
    callable fills one chunk in place (top to bottom). Useful for
    weights that are a concatenation of logically separate matrices,
    e.g. gated-RNN parameter blocks.
    """
    # Work on raw data if this is a Variable-style wrapper.
    tensor = getattr(tensor, 'data', tensor)

    (total_rows, fan_out) = tensor.size()
    rows_per_chunk = total_rows // len(inits)

    # One scratch buffer, refilled by each init and copied into place.
    scratch = tensor.new(rows_per_chunk, fan_out)
    offset = 0
    for init_fn in inits:
        init_fn(scratch)
        tensor[offset : offset + rows_per_chunk, :] = scratch
        offset += rows_per_chunk
+
+
def sequence_nll_loss_bits(input, target, *args, **kwargs):
    """Negative log-likelihood over a sequence batch, measured in bits.

    `input` holds log-probabilities of shape (batch, steps, n_classes)
    and `target` the class indices of shape (batch, steps). Both are
    flattened so nll_loss averages over every (batch, step) position,
    and the result is converted from nats to bits.
    """
    # Unpacking three dims doubles as a shape check on `input`.
    (_, _, n_classes) = input.size()
    flat_input = input.view(-1, n_classes)
    flat_target = target.view(-1)
    nats = nn.functional.nll_loss(flat_input, flat_target, *args, **kwargs)
    # log2(e) converts nats -> bits.
    return nats * math.log(math.e, 2)