diff options
| author | Piotr Kozakowski <kozak000@gmail.com> | 2017-11-19 20:23:26 +0100 |
|---|---|---|
| committer | Piotr Kozakowski <kozak000@gmail.com> | 2017-11-19 20:23:26 +0100 |
| commit | 4167442627b1414ff8fdc86528812b46168c656b (patch) | |
| tree | f5020d2161762fad2db56f3f9ddcb3ad2deec553 | |
| parent | 61e935ff5a90c8c7b9a5a5f2f54d4ec8f9742dc0 (diff) | |

Add weight normalization

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | README.md | 2 |
| -rw-r--r-- | model.py | 25 |
| -rw-r--r-- | train.py | 8 |

3 files changed, 28 insertions, 7 deletions

```diff
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ A PyTorch implementation of [SampleRNN: An Unconditional End-to-End Neural Audio
-It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with arbitrary number of tiers, whereas the original implementation allows maximum 3 tiers. However it doesn't have weight normalization and doesn't allow using LSTM units (only GRU). For more details and motivation behind rewriting this model to PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
+It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with arbitrary number of tiers, whereas the original implementation allows maximum 3 tiers. However it doesn't allow using LSTM units (only GRU). For more details and motivation behind rewriting this model to PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
 
 ## Dependencies
--- a/model.py
+++ b/model.py
@@ -10,7 +10,8 @@ import numpy as np
 
 class SampleRNN(torch.nn.Module):
 
-    def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels):
+    def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels,
+                 weight_norm):
         super().__init__()
 
         self.dim = dim
@@ -19,14 +20,16 @@ class SampleRNN(torch.nn.Module):
         ns_frame_samples = map(int, np.cumprod(frame_sizes))
         self.frame_level_rnns = torch.nn.ModuleList([
             FrameLevelRNN(
-                frame_size, n_frame_samples, n_rnn, dim, learn_h0
+                frame_size, n_frame_samples, n_rnn, dim, learn_h0, weight_norm
             )
             for (frame_size, n_frame_samples) in zip(
                 frame_sizes, ns_frame_samples
             )
         ])
 
-        self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels)
+        self.sample_level_mlp = SampleLevelMLP(
+            frame_sizes[0], dim, q_levels, weight_norm
+        )
 
     @property
     def lookback(self):
@@ -36,7 +39,7 @@ class SampleRNN(torch.nn.Module):
 class FrameLevelRNN(torch.nn.Module):
 
     def __init__(self, frame_size, n_frame_samples, n_rnn, dim,
-                 learn_h0):
+                 learn_h0, weight_norm):
         super().__init__()
 
         self.frame_size = frame_size
@@ -56,6 +59,8 @@ class FrameLevelRNN(torch.nn.Module):
         )
         init.kaiming_uniform(self.input_expand.weight)
         init.constant(self.input_expand.bias, 0)
+        if weight_norm:
+            self.input_expand = torch.nn.utils.weight_norm(self.input_expand)
 
         self.rnn = torch.nn.GRU(
             input_size=dim,
@@ -85,6 +90,10 @@ class FrameLevelRNN(torch.nn.Module):
             self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim)
         )
         init.constant(self.upsampling.bias, 0)
+        if weight_norm:
+            self.upsampling.conv_t = torch.nn.utils.weight_norm(
+                self.upsampling.conv_t
+            )
 
     def forward(self, prev_samples, upper_tier_conditioning, hidden):
         (batch_size, _, _) = prev_samples.size()
@@ -113,7 +122,7 @@ class FrameLevelRNN(torch.nn.Module):
 
 class SampleLevelMLP(torch.nn.Module):
 
-    def __init__(self, frame_size, dim, q_levels):
+    def __init__(self, frame_size, dim, q_levels, weight_norm):
         super().__init__()
 
         self.q_levels = q_levels
@@ -130,6 +139,8 @@ class SampleLevelMLP(torch.nn.Module):
             bias=False
         )
         init.kaiming_uniform(self.input.weight)
+        if weight_norm:
+            self.input = torch.nn.utils.weight_norm(self.input)
 
         self.hidden = torch.nn.Conv1d(
             in_channels=dim,
@@ -138,6 +149,8 @@ class SampleLevelMLP(torch.nn.Module):
         )
         init.kaiming_uniform(self.hidden.weight)
         init.constant(self.hidden.bias, 0)
+        if weight_norm:
+            self.hidden = torch.nn.utils.weight_norm(self.hidden)
 
         self.output = torch.nn.Conv1d(
             in_channels=dim,
@@ -146,6 +159,8 @@ class SampleLevelMLP(torch.nn.Module):
         )
         nn.lecun_uniform(self.output.weight)
         init.constant(self.output.bias, 0)
+        if weight_norm:
+            self.output = torch.nn.utils.weight_norm(self.output)
 
     def forward(self, prev_samples, upper_tier_conditioning):
         (batch_size, _, _) = upper_tier_conditioning.size()
--- a/train.py
+++ b/train.py
@@ -35,6 +35,7 @@ default_params = {
     'learn_h0': True,
     'q_levels': 256,
     'seq_len': 1024,
+    'weight_norm': True,
     'batch_size': 128,
     'val_frac': 0.1,
     'test_frac': 0.1,
@@ -175,7 +176,8 @@ def main(exp, frame_sizes, dataset, **params):
         n_rnn=params['n_rnn'],
         dim=params['dim'],
         learn_h0=params['learn_h0'],
-        q_levels=params['q_levels']
+        q_levels=params['q_levels'],
+        weight_norm=params['weight_norm']
     )
     predictor = Predictor(model)
     if params['cuda']:
@@ -300,6 +302,10 @@ if __name__ == '__main__':
         '--seq_len', type=int,
         help='how many samples to include in each truncated BPTT pass'
     )
+    parser.add_argument(
+        '--weight_norm', type=parse_bool,
+        help='whether to use weight normalization'
+    )
     parser.add_argument('--batch_size', type=int, help='batch size')
     parser.add_argument(
         '--val_frac', type=float,
```
