summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 19:15:53 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 19:15:53 +0200
commit9766ef0f3a539be7ee68bb93918f25a3298afe39 (patch)
tree442e2d0d86f6475d7b3d1ca0e288748e2a7f8f4d
parent4167442627b1414ff8fdc86528812b46168c656b (diff)
stub in primer
-rw-r--r--model.py2
-rw-r--r--train.py1
2 files changed, 2 insertions, 1 deletions
diff --git a/model.py b/model.py
index d3716d1..567e533 100644
--- a/model.py
+++ b/model.py
@@ -242,7 +242,7 @@ class Generator(Runner):
super().__init__(model)
self.cuda = cuda
- def __call__(self, n_seqs, seq_len):
+ def __call__(self, n_seqs, seq_len, primer):
# generation doesn't work with CUDNN for some reason
torch.backends.cudnn.enabled = False
diff --git a/train.py b/train.py
index e3061c9..1f28396 100644
--- a/train.py
+++ b/train.py
@@ -52,6 +52,7 @@ default_params = {
'loss_smoothing': 0.99,
'cuda': True,
'comet_key': None
+ 'primer': 'zero',
}
tag_params = [