summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--model.py2
-rw-r--r--train.py1
2 files changed, 2 insertions, 1 deletions
diff --git a/model.py b/model.py
index d3716d1..567e533 100644
--- a/model.py
+++ b/model.py
@@ -242,7 +242,7 @@ class Generator(Runner):
super().__init__(model)
self.cuda = cuda
- def __call__(self, n_seqs, seq_len):
+ def __call__(self, n_seqs, seq_len, primer):
# generation doesn't work with CUDNN for some reason
torch.backends.cudnn.enabled = False
diff --git a/train.py b/train.py
index e3061c9..1f28396 100644
--- a/train.py
+++ b/train.py
@@ -52,6 +52,7 @@ default_params = {
'loss_smoothing': 0.99,
'cuda': True,
'comet_key': None
+ 'primer': 'zero',
}
tag_params = [