From 4167442627b1414ff8fdc86528812b46168c656b Mon Sep 17 00:00:00 2001 From: Piotr Kozakowski Date: Sun, 19 Nov 2017 20:23:26 +0100 Subject: Add weight normalization --- model.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) (limited to 'model.py') diff --git a/model.py b/model.py index d2db1a9..d3716d1 100644 --- a/model.py +++ b/model.py @@ -10,7 +10,8 @@ import numpy as np class SampleRNN(torch.nn.Module): - def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels): + def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels, + weight_norm): super().__init__() self.dim = dim @@ -19,14 +20,16 @@ class SampleRNN(torch.nn.Module): ns_frame_samples = map(int, np.cumprod(frame_sizes)) self.frame_level_rnns = torch.nn.ModuleList([ FrameLevelRNN( - frame_size, n_frame_samples, n_rnn, dim, learn_h0 + frame_size, n_frame_samples, n_rnn, dim, learn_h0, weight_norm ) for (frame_size, n_frame_samples) in zip( frame_sizes, ns_frame_samples ) ]) - self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels) + self.sample_level_mlp = SampleLevelMLP( + frame_sizes[0], dim, q_levels, weight_norm + ) @property def lookback(self): @@ -36,7 +39,7 @@ class SampleRNN(torch.nn.Module): class FrameLevelRNN(torch.nn.Module): def __init__(self, frame_size, n_frame_samples, n_rnn, dim, - learn_h0): + learn_h0, weight_norm): super().__init__() self.frame_size = frame_size @@ -56,6 +59,8 @@ class FrameLevelRNN(torch.nn.Module): ) init.kaiming_uniform(self.input_expand.weight) init.constant(self.input_expand.bias, 0) + if weight_norm: + self.input_expand = torch.nn.utils.weight_norm(self.input_expand) self.rnn = torch.nn.GRU( input_size=dim, @@ -85,6 +90,10 @@ class FrameLevelRNN(torch.nn.Module): self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim) ) init.constant(self.upsampling.bias, 0) + if weight_norm: + self.upsampling.conv_t = torch.nn.utils.weight_norm( + self.upsampling.conv_t + ) def forward(self, prev_samples, upper_tier_conditioning, hidden): (batch_size, _, _) = prev_samples.size() @@ -113,7 +122,7 @@ class FrameLevelRNN(torch.nn.Module): class SampleLevelMLP(torch.nn.Module): - def __init__(self, frame_size, dim, q_levels): + def __init__(self, frame_size, dim, q_levels, weight_norm): super().__init__() self.q_levels = q_levels @@ -130,6 +139,8 @@ class SampleLevelMLP(torch.nn.Module): bias=False ) init.kaiming_uniform(self.input.weight) + if weight_norm: + self.input = torch.nn.utils.weight_norm(self.input) self.hidden = torch.nn.Conv1d( in_channels=dim, @@ -138,6 +149,8 @@ class SampleLevelMLP(torch.nn.Module): ) init.kaiming_uniform(self.hidden.weight) init.constant(self.hidden.bias, 0) + if weight_norm: + self.hidden = torch.nn.utils.weight_norm(self.hidden) self.output = torch.nn.Conv1d( in_channels=dim, @@ -146,6 +159,8 @@ class SampleLevelMLP(torch.nn.Module): ) nn.lecun_uniform(self.output.weight) init.constant(self.output.bias, 0) + if weight_norm: + self.output = torch.nn.utils.weight_norm(self.output) def forward(self, prev_samples, upper_tier_conditioning): (batch_size, _, _) = upper_tier_conditioning.size() -- cgit v1.2.3-70-g09d2