From eee8b5651d61268e0d025acda0a659ae88c951ce Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Mon, 14 May 2018 23:08:41 +0200 Subject: okayyyyyyy --- model.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'model.py') diff --git a/model.py b/model.py index c1f3cd3..bbbd6ea 100644 --- a/model.py +++ b/model.py @@ -438,7 +438,12 @@ class PrimedGenerator(Runner): prev_samples, upper_tier_conditioning ).squeeze(1).exp_().data print(sample_dist.shape) - out_sequences[:, i] = sample_dist.multinomial(1).squeeze(1) + multi = sample_dist.multinomial(1) + print(multi.shape) + pred = multi.squeeze(1) + print(pred.shape) + print(out_sequences.shape) + out_sequences[:, i] = pred torch.backends.cudnn.enabled = True -- cgit v1.2.3-70-g09d2