diff options
Diffstat (limited to 'trainer')
| -rw-r--r-- | trainer/plugins.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py index dc3b24a..562355e 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -1,7 +1,7 @@ import matplotlib matplotlib.use('Agg') -from model import Generator +from model import Generator, PrimedGenerator import torch from torch.autograd import Variable @@ -152,7 +152,10 @@ class GeneratorPlugin(Plugin): self.primer = primer def register(self, trainer): - self.generate = Generator(trainer.model.model, trainer.cuda) + if self.primer == "": + self.generate = Generator(trainer.model.model, trainer.cuda) + else: + self.generate = PrimedGenerator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): samples = self.generate(self.n_samples, self.sample_length, self.primer) \ |
