summaryrefslogtreecommitdiff
path: root/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'model.py')
-rw-r--r--model.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/model.py b/model.py
index c1f3cd3..bbbd6ea 100644
--- a/model.py
+++ b/model.py
@@ -438,7 +438,12 @@ class PrimedGenerator(Runner):
prev_samples, upper_tier_conditioning
).squeeze(1).exp_().data
print(sample_dist.shape)
- out_sequences[:, i] = sample_dist.multinomial(1).squeeze(1)
+ multi = sample_dist.multinomial(1)
+ print(multi.shape)
+ pred = multi.squeeze(1)
+ print(pred.shape)
+ print(out_sequences.shape)
+ out_sequences[:, i] = pred
torch.backends.cudnn.enabled = True