summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Piotr Kozakowski <kozak000@gmail.com> 2017-11-19 20:23:26 +0100
committer Piotr Kozakowski <kozak000@gmail.com> 2017-11-19 20:23:26 +0100
commit 4167442627b1414ff8fdc86528812b46168c656b (patch)
tree f5020d2161762fad2db56f3f9ddcb3ad2deec553
parent 61e935ff5a90c8c7b9a5a5f2f54d4ec8f9742dc0 (diff)
Add weight normalization
-rw-r--r-- README.md 2
-rw-r--r-- model.py 25
-rw-r--r-- train.py 8
3 files changed, 28 insertions, 7 deletions
diff --git a/README.md b/README.md
index 34e8530..4d39632 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ A PyTorch implementation of [SampleRNN: An Unconditional End-to-End Neural Audio
![A visual representation of the SampleRNN architecture](http://deepsound.io/images/samplernn.png)
-It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with an arbitrary number of tiers, whereas the original implementation allows a maximum of 3 tiers. However, it doesn't have weight normalization and doesn't allow using LSTM units (only GRU). For more details and the motivation behind rewriting this model in PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
+It's based on the reference implementation in Theano: https://github.com/soroushmehr/sampleRNN_ICLR2017. Unlike the Theano version, our code allows training models with an arbitrary number of tiers, whereas the original implementation allows a maximum of 3 tiers. However, it doesn't allow using LSTM units (only GRU). For more details and the motivation behind rewriting this model in PyTorch, see our blog post: http://deepsound.io/samplernn_pytorch.html.
## Dependencies
diff --git a/model.py b/model.py
index d2db1a9..d3716d1 100644
--- a/model.py
+++ b/model.py
@@ -10,7 +10,8 @@ import numpy as np
class SampleRNN(torch.nn.Module):
- def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels):
+ def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels,
+ weight_norm):
super().__init__()
self.dim = dim
@@ -19,14 +20,16 @@ class SampleRNN(torch.nn.Module):
ns_frame_samples = map(int, np.cumprod(frame_sizes))
self.frame_level_rnns = torch.nn.ModuleList([
FrameLevelRNN(
- frame_size, n_frame_samples, n_rnn, dim, learn_h0
+ frame_size, n_frame_samples, n_rnn, dim, learn_h0, weight_norm
)
for (frame_size, n_frame_samples) in zip(
frame_sizes, ns_frame_samples
)
])
- self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels)
+ self.sample_level_mlp = SampleLevelMLP(
+ frame_sizes[0], dim, q_levels, weight_norm
+ )
@property
def lookback(self):
@@ -36,7 +39,7 @@ class SampleRNN(torch.nn.Module):
class FrameLevelRNN(torch.nn.Module):
def __init__(self, frame_size, n_frame_samples, n_rnn, dim,
- learn_h0):
+ learn_h0, weight_norm):
super().__init__()
self.frame_size = frame_size
@@ -56,6 +59,8 @@ class FrameLevelRNN(torch.nn.Module):
)
init.kaiming_uniform(self.input_expand.weight)
init.constant(self.input_expand.bias, 0)
+ if weight_norm:
+ self.input_expand = torch.nn.utils.weight_norm(self.input_expand)
self.rnn = torch.nn.GRU(
input_size=dim,
@@ -85,6 +90,10 @@ class FrameLevelRNN(torch.nn.Module):
self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim)
)
init.constant(self.upsampling.bias, 0)
+ if weight_norm:
+ self.upsampling.conv_t = torch.nn.utils.weight_norm(
+ self.upsampling.conv_t
+ )
def forward(self, prev_samples, upper_tier_conditioning, hidden):
(batch_size, _, _) = prev_samples.size()
@@ -113,7 +122,7 @@ class FrameLevelRNN(torch.nn.Module):
class SampleLevelMLP(torch.nn.Module):
- def __init__(self, frame_size, dim, q_levels):
+ def __init__(self, frame_size, dim, q_levels, weight_norm):
super().__init__()
self.q_levels = q_levels
@@ -130,6 +139,8 @@ class SampleLevelMLP(torch.nn.Module):
bias=False
)
init.kaiming_uniform(self.input.weight)
+ if weight_norm:
+ self.input = torch.nn.utils.weight_norm(self.input)
self.hidden = torch.nn.Conv1d(
in_channels=dim,
@@ -138,6 +149,8 @@ class SampleLevelMLP(torch.nn.Module):
)
init.kaiming_uniform(self.hidden.weight)
init.constant(self.hidden.bias, 0)
+ if weight_norm:
+ self.hidden = torch.nn.utils.weight_norm(self.hidden)
self.output = torch.nn.Conv1d(
in_channels=dim,
@@ -146,6 +159,8 @@ class SampleLevelMLP(torch.nn.Module):
)
nn.lecun_uniform(self.output.weight)
init.constant(self.output.bias, 0)
+ if weight_norm:
+ self.output = torch.nn.utils.weight_norm(self.output)
def forward(self, prev_samples, upper_tier_conditioning):
(batch_size, _, _) = upper_tier_conditioning.size()
diff --git a/train.py b/train.py
index 8bab87b..e3061c9 100644
--- a/train.py
+++ b/train.py
@@ -35,6 +35,7 @@ default_params = {
'learn_h0': True,
'q_levels': 256,
'seq_len': 1024,
+ 'weight_norm': True,
'batch_size': 128,
'val_frac': 0.1,
'test_frac': 0.1,
@@ -175,7 +176,8 @@ def main(exp, frame_sizes, dataset, **params):
n_rnn=params['n_rnn'],
dim=params['dim'],
learn_h0=params['learn_h0'],
- q_levels=params['q_levels']
+ q_levels=params['q_levels'],
+ weight_norm=params['weight_norm']
)
predictor = Predictor(model)
if params['cuda']:
@@ -300,6 +302,10 @@ if __name__ == '__main__':
'--seq_len', type=int,
help='how many samples to include in each truncated BPTT pass'
)
+ parser.add_argument(
+ '--weight_norm', type=parse_bool,
+ help='whether to use weight normalization'
+ )
parser.add_argument('--batch_size', type=int, help='batch size')
parser.add_argument(
'--val_frac', type=float,