summaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 20:38:13 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 20:38:13 +0200
commit914ffba685b594c11bbfb8559a44f54042e40d83 (patch)
tree777948436d295dce892ffc817a2bddb7f7ec176b /model.py
parent45e252f2fc33992b52dc34251c0ba31970b92ee5 (diff)
get pl param printing being nice
Diffstat (limited to 'model.py')
-rw-r--r--model.py13
1 files changed, 9 insertions, 4 deletions
diff --git a/model.py b/model.py
index 35317f3..b07b6b8 100644
--- a/model.py
+++ b/model.py
@@ -310,7 +310,7 @@ class PrimedGenerator(Runner):
super().__init__(model)
self.cuda = cuda
- def __call__(self, n_seqs, seq_len, primer):
+ def __call__(self, n_seqs, seq_len, primer, prime_param_a, prime_param_b, recursive):
# generation doesn't work with CUDNN for some reason
torch.backends.cudnn.enabled = False
@@ -321,10 +321,14 @@ class PrimedGenerator(Runner):
sequences = torch.LongTensor(n_seqs, n_samples) # 64-bit int
frame_level_outputs = [None for _ in self.model.frame_level_rnns]
+ if recursive:
+ out_sequences = sequences
+ else:
+ out_sequences = torch.LongTensor(n_seqs, n_samples).fill_(utils.q_zero(self.model.q_levels))
+
q_levels = self.model.q_levels
q_min = 0
q_max = q_levels
- primer_freq = 440
print("_______-___-_---_-____")
print("_____________--_-_-_______")
@@ -337,6 +341,7 @@ class PrimedGenerator(Runner):
for i in xrange(n_samples):
x[:, i] = random.triangular(q_min, q_max)
def sin(x):
+ primer_freq = prime_param_a
for i in xrange(n_samples):
x[:, i] = (math.sin(i/44100 * primer_freq) + 1) / 2 * (q_max - q_min) + q_min
@@ -387,8 +392,8 @@ class PrimedGenerator(Runner):
sample_dist = self.model.sample_level_mlp(
prev_samples, upper_tier_conditioning
).squeeze(1).exp_().data
- sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
+ out_sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
torch.backends.cudnn.enabled = True
- return sequences[:, self.model.lookback :]
+ return out_sequences[:, self.model.lookback :]