summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 22:41:11 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 22:41:11 +0200
commit91f53bd920b9e883caceb910d6d263039ce367a7 (patch)
tree937c37fdb43b1da7c4dbf5cc25fa3e47099cb257
parent6684cfa723839990cb3ca165dd5d5a46be6e23e3 (diff)
whyyyy
-rwxr-xr-xgen-prime.sh4
-rw-r--r--model.py9
2 files changed, 6 insertions, 7 deletions
diff --git a/gen-prime.sh b/gen-prime.sh
index b4e5614..2fb638c 100755
--- a/gen-prime.sh
+++ b/gen-prime.sh
@@ -12,7 +12,7 @@ function gen_prime () {
echo "___________________________________________________"
echo ">> generating $exp_name"
echo ""
- CUDA_LAUNCH_BLOCKING=1 python generate.py \
+ python generate.py \
--exp $exp_name --dataset $exp_name \
--frame_sizes 8 2 \
--n_rnn 2 --dim 1024 --q_levels 256 \
@@ -27,8 +27,8 @@ function gen_prime () {
--primer_a $primer_a \
--primer_b $primer_b \
--recursive $recursive \
- --cuda False \
--resume True
+ # --cuda False \
tag="${primer}" # _${sample_length}"
diff --git a/model.py b/model.py
index aa52088..4b83ad7 100644
--- a/model.py
+++ b/model.py
@@ -367,15 +367,14 @@ class PrimedGenerator(Runner):
for j in range(n):
ratio = i / (n-1)
- for k in range(n_seqs):
- a = (1-ratio) * sub_sequence_a[k, j]
- b = ratio * sub_sequence_b[k, j]
- tmp_sub_sequence[k, j] = a + b
+ a = sub_sequence_a[:, j].float() * (1-ratio)
+ b = sub_sequence_b[:, j].float() * ratio
+ tmp_sub_sequence[:, j] = (a + b).long()
return tmp_sub_sequence
for i in range(self.model.lookback, self.model.lookback + seq_len):
- if i > 0 and (i % 1000) == 0:
+ if (i % 1000) == 0:
print("{}...".format(i))
for (tier_index, rnn) in \