| author | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 23:36:49 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 23:36:49 +0200 |
| commit | 6917c6bd087daa5b3a140bdce0376cb1ea8d2cbc | |
| tree | a6a0179cfb37d2d18656425d199dddf20a59a7bc | |
| parent | bb9d6b928e8ea2ead95cdad14e1b68dd870d16c1 | |
okayyyyyyy
| -rwxr-xr-x | gen-prime.sh | 4 |
| -rw-r--r-- | model.py | 15 |
2 files changed, 15 insertions, 4 deletions
diff --git a/gen-prime.sh b/gen-prime.sh
index 19ef340..3f1ee11 100755
--- a/gen-prime.sh
+++ b/gen-prime.sh
@@ -54,8 +54,8 @@ function gen_prime_set () {
   # gen_prime $1 6 44100 'zero'
   # gen_prime $1 6 44100 'noise'
   # gen_prime $1 6 44100 'sin' 440
-  gen_prime $1 1 100 'noise' 0 0 True
-  #gen_prime $1 6 100 'sin' 440 0 True
+  gen_prime $1 6 44100 'noise' 0 0 True
+  gen_prime $1 6 44100 'sin' 440 0 True
 }
 
 gen_prime_set jwcglassbeat
diff --git a/model.py b/model.py
--- a/model.py
+++ b/model.py
@@ -329,7 +329,10 @@ class PrimedGenerator(Runner):
         q_max = q_levels
 
         print("_______-___-_---_-____")
+        print("_____________--_-_-_______")
         print("INITTTTTTTT {}".format(primer))
+        if recursive:
+            print("RECURSIVE")
         print(sequences.shape)
         print("__________________--_-__--_________________")
         print("__-__________-_______________")
@@ -403,6 +406,7 @@ class PrimedGenerator(Runner):
             )
 
             sub_sequence = get_sub_sequence(i, bottom_frame_size)
+            # sub_sequence = sequences[:, i-bottom_frame_size : i]
 
             prev_samples = torch.autograd.Variable(
                 sub_sequence,
@@ -411,14 +415,21 @@ class PrimedGenerator(Runner):
             if self.cuda:
                 prev_samples = prev_samples.cuda()
 
+            print("get upper tier conditioning.. {}".format(i % bottom_frame_size))
            upper_tier_conditioning = \
                 frame_level_outputs[0][:, i % bottom_frame_size, :] \
                 .unsqueeze(1)
+            print(upper_tier_conditioning.shape)
 
             sample_dist = self.model.sample_level_mlp(
                 prev_samples, upper_tier_conditioning
             ).squeeze(1).exp_().data
-
-            out_sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
+            print(sample_dist.shape)
+            multi = sample_dist.multinomial(1)
+            print(multi.shape)
+            pred = multi.squeeze(1)
+            print(pred.shape)
+            print(out_sequences.shape)
+            out_sequences[:, i] = pred
 
         torch.backends.cudnn.enabled = True
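
For context on the model.py change: the one-line sampling call `sample_dist.multinomial(1).squeeze(1)` is split into named intermediates so each tensor shape can be printed while debugging, without changing the result written into `out_sequences`. Below is a minimal, self-contained sketch of the same shape flow; the batch size, quantization levels, and sequence length are illustrative assumptions, not values from this repo.

```python
import torch

batch_size, q_levels, seq_len = 2, 256, 100   # assumed values, for illustration only
i = 0                                          # current sample index in the generation loop

# Stand-in for the sample-level MLP output: log-probabilities over q_levels per batch item.
log_probs = torch.randn(batch_size, q_levels).log_softmax(dim=1)
sample_dist = log_probs.exp()                  # shape (batch_size, q_levels)

multi = sample_dist.multinomial(1)             # shape (batch_size, 1): one sampled level per row
pred = multi.squeeze(1)                        # shape (batch_size,)

out_sequences = torch.zeros(batch_size, seq_len, dtype=torch.long)
out_sequences[:, i] = pred                     # same assignment as the pre-refactor one-liner

print(sample_dist.shape, multi.shape, pred.shape, out_sequences.shape)
```

The refactor itself is behavior-preserving; only the added shape prints differ from the previous code path.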
