Diffstat (limited to 'model.py')
 model.py | 25 ++++++++++++++++++++-----
 1 file changed, 20 insertions(+), 5 deletions(-)
diff --git a/model.py b/model.py
index d2db1a9..d3716d1 100644
--- a/model.py
+++ b/model.py
@@ -10,7 +10,8 @@ import numpy as np
 
 class SampleRNN(torch.nn.Module):
 
-    def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels):
+    def __init__(self, frame_sizes, n_rnn, dim, learn_h0, q_levels,
+                 weight_norm):
         super().__init__()
 
         self.dim = dim
@@ -19,14 +20,16 @@ class SampleRNN(torch.nn.Module):
         ns_frame_samples = map(int, np.cumprod(frame_sizes))
         self.frame_level_rnns = torch.nn.ModuleList([
             FrameLevelRNN(
-                frame_size, n_frame_samples, n_rnn, dim, learn_h0
+                frame_size, n_frame_samples, n_rnn, dim, learn_h0, weight_norm
             )
             for (frame_size, n_frame_samples) in zip(
                 frame_sizes, ns_frame_samples
             )
         ])
 
-        self.sample_level_mlp = SampleLevelMLP(frame_sizes[0], dim, q_levels)
+        self.sample_level_mlp = SampleLevelMLP(
+            frame_sizes[0], dim, q_levels, weight_norm
+        )
 
     @property
     def lookback(self):
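The hunk above threads the new weight_norm flag from the SampleRNN constructor down into every FrameLevelRNN tier and into the SampleLevelMLP, so a single argument controls normalization everywhere. A minimal construction sketch; the hyperparameter values are illustrative assumptions, not taken from this patch:

    model = SampleRNN(
        frame_sizes=[16, 4],  # one entry per frame-level tier
        n_rnn=2,              # stacked GRU layers per tier
        dim=1024,             # hidden width shared by all tiers
        learn_h0=True,        # learn the initial GRU hidden state
        q_levels=256,         # number of sample quantization levels
        weight_norm=True,     # the new flag: wrap key layers in weight norm
    )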
@@ -36,7 +39,7 @@ class SampleRNN(torch.nn.Module):
 class FrameLevelRNN(torch.nn.Module):
 
     def __init__(self, frame_size, n_frame_samples, n_rnn, dim,
-                 learn_h0):
+                 learn_h0, weight_norm):
         super().__init__()
 
         self.frame_size = frame_size
@@ -56,6 +59,8 @@ class FrameLevelRNN(torch.nn.Module):
         )
         init.kaiming_uniform(self.input_expand.weight)
         init.constant(self.input_expand.bias, 0)
+        if weight_norm:
+            self.input_expand = torch.nn.utils.weight_norm(self.input_expand)
 
         self.rnn = torch.nn.GRU(
             input_size=dim,
@@ -85,6 +90,10 @@ class FrameLevelRNN(torch.nn.Module):
             self.upsampling.conv_t.weight, -np.sqrt(6 / dim), np.sqrt(6 / dim)
         )
         init.constant(self.upsampling.bias, 0)
+        if weight_norm:
+            self.upsampling.conv_t = torch.nn.utils.weight_norm(
+                self.upsampling.conv_t
+            )
 
     def forward(self, prev_samples, upper_tier_conditioning, hidden):
         (batch_size, _, _) = prev_samples.size()
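Each `if weight_norm:` branch wraps an already-built layer with torch.nn.utils.weight_norm, stock PyTorch that reparameterizes the layer's weight as w = g * v / ||v||: the single `weight` parameter is replaced by a magnitude `weight_g` and a direction `weight_v`, decoupling the weight's scale from its direction during optimization. A standalone sketch of that API (not code from this patch):

    import torch

    conv = torch.nn.Conv1d(in_channels=8, out_channels=16, kernel_size=1)
    conv = torch.nn.utils.weight_norm(conv)  # registers weight_g and weight_v

    # `weight` is now recomputed from weight_g and weight_v on each forward.
    print(sorted(name for name, _ in conv.named_parameters()))
    # ['bias', 'weight_g', 'weight_v']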
@@ -113,7 +122,7 @@ class FrameLevelRNN(torch.nn.Module):
 
 class SampleLevelMLP(torch.nn.Module):
 
-    def __init__(self, frame_size, dim, q_levels):
+    def __init__(self, frame_size, dim, q_levels, weight_norm):
         super().__init__()
 
         self.q_levels = q_levels
@@ -130,6 +139,8 @@ class SampleLevelMLP(torch.nn.Module):
             bias=False
         )
         init.kaiming_uniform(self.input.weight)
+        if weight_norm:
+            self.input = torch.nn.utils.weight_norm(self.input)
 
         self.hidden = torch.nn.Conv1d(
             in_channels=dim,
@@ -138,6 +149,8 @@ class SampleLevelMLP(torch.nn.Module):
         )
         init.kaiming_uniform(self.hidden.weight)
         init.constant(self.hidden.bias, 0)
+        if weight_norm:
+            self.hidden = torch.nn.utils.weight_norm(self.hidden)
 
         self.output = torch.nn.Conv1d(
             in_channels=dim,
@@ -146,6 +159,8 @@ class SampleLevelMLP(torch.nn.Module):
         )
         nn.lecun_uniform(self.output.weight)
         init.constant(self.output.bias, 0)
+        if weight_norm:
+            self.output = torch.nn.utils.weight_norm(self.output)
 
     def forward(self, prev_samples, upper_tier_conditioning):
         (batch_size, _, _) = upper_tier_conditioning.size()
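Throughout the patch the order is initialize-then-wrap: kaiming_uniform, lecun_uniform (the project's own initializer rather than a stock torch.nn function), or uniform runs first, and weight_norm is applied afterwards. That order matters because weight_norm splits the weight's current value into g = ||w|| (per output channel, dim=0 by default) and v = w, so the effective weight immediately after wrapping is numerically unchanged. A small sketch checking that invariant, written against a newer PyTorch where the init functions carry a trailing underscore:

    import torch
    from torch.nn import init

    layer = torch.nn.Conv1d(in_channels=4, out_channels=4, kernel_size=1)
    init.kaiming_uniform_(layer.weight)   # the patch uses the older spelling, kaiming_uniform
    before = layer.weight.detach().clone()

    layer = torch.nn.utils.weight_norm(layer)  # weight -> weight_g, weight_v
    after = layer.weight.detach().clone()

    print(torch.allclose(before, after))  # True: wrapping preserves the initialized values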