diff options
| author | Piotr Kozakowski <kozak000@gmail.com> | 2017-05-11 17:49:12 +0200 |
|---|---|---|
| committer | Piotr Kozakowski <kozak000@gmail.com> | 2017-06-29 15:37:26 +0200 |
| commit | 2e308fe8e90276a892637be1bfa174e673ebf414 (patch) | |
| tree | 4ff187b37d16476cc936aba84184b8feca9c8612 /nn.py | |
| parent | 253860fdb0949f0eab6abff09369b0a1236b541a (diff) | |
Implement SampleRNN
Diffstat (limited to 'nn.py')
| -rw-r--r-- | nn.py | 70 |
1 file changed, 70 insertions, 0 deletions
import torch
from torch import nn

import math


class LearnedUpsampling1d(nn.Module):
    """Upsample a 1D signal by an integer factor with learned weights.

    Implemented as a ConvTranspose1d whose stride equals its kernel size,
    so every input step expands into ``kernel_size`` non-overlapping output
    steps (output length == input length * kernel_size).  Optionally adds a
    learned bias with one entry per (output channel, phase-within-block),
    which is richer than ConvTranspose1d's single per-channel bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()

        # stride == kernel_size -> non-overlapping blocks, exact integer
        # upsampling.  The transposed conv's own bias is disabled because we
        # manage a per-phase bias ourselves below.
        self.conv_t = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False
        )

        if bias:
            # One bias value per (channel, phase within the upsampled block).
            self.bias = nn.Parameter(
                torch.FloatTensor(out_channels, kernel_size)
            )
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the conv weights and zero the per-phase bias."""
        self.conv_t.reset_parameters()
        # BUG FIX: the original called nn.init.constant(self.bias, 0)
        # unconditionally, which raises when the module was constructed with
        # bias=False (self.bias is None).  Also switched to the in-place
        # initializer `constant_`; the bare `constant` alias is deprecated.
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, input):
        """Upsample ``input`` of shape (batch, in_channels, length).

        Returns a tensor of shape (batch, out_channels,
        length * kernel_size).
        """
        out = self.conv_t(input)
        if self.bias is None:
            # BUG FIX: the original dereferenced self.bias unconditionally,
            # crashing in the bias=False configuration.
            return out
        (batch_size, _, length) = input.size()
        (kernel_size,) = self.conv_t.kernel_size
        # Tile the (out_channels, kernel_size) bias over the batch and time
        # axes so each phase of every upsampled block gets its own offset:
        # (C, K) -> (B, C, L, K) -> (B, C, L*K).
        bias = self.bias.unsqueeze(0).unsqueeze(2).expand(
            batch_size, self.conv_t.out_channels,
            length, kernel_size
        ).contiguous().view(
            batch_size, self.conv_t.out_channels,
            length * kernel_size
        )
        return out + bias


def lecun_uniform(tensor):
    """Fill ``tensor`` in place with LeCun uniform init.

    Samples U(-sqrt(3/fan_in), sqrt(3/fan_in)), which preserves activation
    variance for fan_in inputs.
    """
    # NOTE(review): depends on the private helper
    # nn.init._calculate_correct_fan; still present in current torch but not
    # public API — confirm on upgrades.
    fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
    bound = math.sqrt(3 / fan_in)
    # `uniform_` is the in-place initializer; the bare `uniform` alias is
    # deprecated.
    nn.init.uniform_(tensor, -bound, bound)


def concat_init(tensor, inits):
    """Initialize ``tensor`` as a vertical stack of equally-sized chunks.

    ``tensor`` has shape (len(inits) * fan_in, fan_out); chunk ``i`` of rows
    is filled by calling ``inits[i]`` on a scratch tensor of shape
    (fan_in, fan_out) and copying it in.  Useful e.g. for GRU/LSTM weight
    matrices that concatenate several gates along dim 0.
    """
    try:
        # Unwrap nn.Parameter (or old-style Variable) to its raw tensor so
        # the in-place writes below do not touch autograd.
        tensor = tensor.data
    except AttributeError:
        pass

    (length, fan_out) = tensor.size()
    fan_in = length // len(inits)

    chunk = tensor.new(fan_in, fan_out)
    for (i, init) in enumerate(inits):
        init(chunk)
        tensor[i * fan_in : (i + 1) * fan_in, :] = chunk


def sequence_nll_loss_bits(input, target, *args, **kwargs):
    """NLL loss for a sequence of log-probabilities, measured in bits.

    ``input`` has shape (batch, seq_len, n_classes) and holds log-probs;
    ``target`` has shape (batch, seq_len).  Both are flattened so nll_loss
    averages over every position.  Extra args/kwargs are forwarded to
    ``nn.functional.nll_loss``.
    """
    (_, _, n_classes) = input.size()
    # math.log(math.e, 2) == log2(e): converts the loss from nats to bits.
    return nn.functional.nll_loss(
        input.view(-1, n_classes), target.view(-1), *args, **kwargs
    ) * math.log(math.e, 2)
