summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 23:36:49 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 23:36:49 +0200
commit6917c6bd087daa5b3a140bdce0376cb1ea8d2cbc (patch)
treea6a0179cfb37d2d18656425d199dddf20a59a7bc
parentbb9d6b928e8ea2ead95cdad14e1b68dd870d16c1 (diff)
okayyyyyyy
-rwxr-xr-xgen-prime.sh4
-rw-r--r--model.py15
2 files changed, 15 insertions, 4 deletions
diff --git a/gen-prime.sh b/gen-prime.sh
index 19ef340..3f1ee11 100755
--- a/gen-prime.sh
+++ b/gen-prime.sh
@@ -54,8 +54,8 @@ function gen_prime_set () {
# gen_prime $1 6 44100 'zero'
# gen_prime $1 6 44100 'noise'
# gen_prime $1 6 44100 'sin' 440
- gen_prime $1 1 100 'noise' 0 0 True
- #gen_prime $1 6 100 'sin' 440 0 True
+ gen_prime $1 6 44100 'noise' 0 0 True
+ gen_prime $1 6 44100 'sin' 440 0 True
}
gen_prime_set jwcglassbeat
diff --git a/model.py b/model.py
index 0a3dc54..12c0e05 100644
--- a/model.py
+++ b/model.py
@@ -329,7 +329,10 @@ class PrimedGenerator(Runner):
q_max = q_levels
print("_______-___-_---_-____")
+ print("_____________--_-_-_______")
print("INITTTTTTTT {}".format(primer))
+ if recursive:
+ print("RECURSIVE")
print(sequences.shape)
print("__________________--_-__--_________________")
print("__-__________-_______________")
@@ -403,6 +406,7 @@ class PrimedGenerator(Runner):
)
sub_sequence = get_sub_sequence(i, bottom_frame_size)
+ # sub_sequence = sequences[:, i-bottom_frame_size : i]
prev_samples = torch.autograd.Variable(
sub_sequence,
@@ -411,14 +415,21 @@ class PrimedGenerator(Runner):
if self.cuda:
prev_samples = prev_samples.cuda()
+ print("get upper tier conditioning.. {}".format(i % bottom_frame_size))
upper_tier_conditioning = \
frame_level_outputs[0][:, i % bottom_frame_size, :] \
.unsqueeze(1)
+ print(upper_tier_conditioning.shape)
sample_dist = self.model.sample_level_mlp(
prev_samples, upper_tier_conditioning
).squeeze(1).exp_().data
-
- out_sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
+ print(sample_dist.shape)
+ multi = sample_dist.multinomial(1)
+ print(multi.shape)
+ pred = multi.squeeze(1)
+ print(pred.shape)
+ print(out_sequences.shape)
+ out_sequences[:, i] = pred
torch.backends.cudnn.enabled = True