summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 23:27:16 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 23:27:16 +0200
commitba99237f1ab258e8f5eac0d01719e92751c9b2a3 (patch)
treec19bd3b4a7157beae2c2329f4ae1e90288480b09
parent410e8ed4f3b2e4debe61d510e23eb8401f238c13 (diff)
okayyyyyyy
-rwxr-xr-xgen-prime.sh2
-rw-r--r--model.py5
2 files changed, 4 insertions, 3 deletions
diff --git a/gen-prime.sh b/gen-prime.sh
index 28c9ddf..19ef340 100755
--- a/gen-prime.sh
+++ b/gen-prime.sh
@@ -12,7 +12,7 @@ function gen_prime () {
echo "___________________________________________________"
echo ">> generating $exp_name"
echo ""
- python generate.py \
+ CUDA_LAUNCH_BLOCKING=1 python generate.py \
--exp $exp_name --dataset $exp_name \
--frame_sizes 8 2 \
--n_rnn 2 --dim 1024 --q_levels 256 \
diff --git a/model.py b/model.py
index bbbd6ea..ff35baa 100644
--- a/model.py
+++ b/model.py
@@ -378,7 +378,7 @@ class PrimedGenerator(Runner):
ratio = i / (n-1)
a = sub_sequence_a[:, j].float() * (1-ratio)
b = sub_sequence_b[:, j].float() * ratio
- tmp_sub_sequence[:, j] = (a + b).long()
+ tmp_sub_sequence[:, j] = clamp(a + b, 0, q_levels).long()
return tmp_sub_sequence
@@ -420,7 +420,8 @@ class PrimedGenerator(Runner):
print("ran rnn")
print("at bottom frame")
- sub_sequence = get_sub_sequence(i, bottom_frame_size)
+ # sub_sequence = get_sub_sequence(i, bottom_frame_size)
+ sub_sequence = sequences[:, i-bottom_frame_size : i]
prev_samples = torch.autograd.Variable(
sub_sequence,