summaryrefslogtreecommitdiff
path: root/trainer
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-05-14 20:02:19 +0200
committerJules Laplace <julescarbon@gmail.com>2018-05-14 20:02:19 +0200
commit7de430fe12bc6ebdab054884dddda20463431181 (patch)
tree50a63f7af38788bde1e334a02d347fa75ff111ea /trainer
parent88eecd62b75cc032752aa10121d376cc7bca418b (diff)
generator function..
Diffstat (limited to 'trainer')
-rw-r--r--trainer/plugins.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py
index dc3b24a..562355e 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -1,7 +1,7 @@
import matplotlib
matplotlib.use('Agg')
-from model import Generator
+from model import Generator, PrimedGenerator
import torch
from torch.autograd import Variable
@@ -152,7 +152,10 @@ class GeneratorPlugin(Plugin):
self.primer = primer
def register(self, trainer):
- self.generate = Generator(trainer.model.model, trainer.cuda)
+ if self.primer == "":
+ self.generate = Generator(trainer.model.model, trainer.cuda)
+ else:
+ self.generate = PrimedGenerator(trainer.model.model, trainer.cuda)
def epoch(self, epoch_index):
samples = self.generate(self.n_samples, self.sample_length, self.primer) \