From 7de430fe12bc6ebdab054884dddda20463431181 Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Mon, 14 May 2018 20:02:19 +0200 Subject: generator function.. --- trainer/plugins.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'trainer') diff --git a/trainer/plugins.py b/trainer/plugins.py index dc3b24a..562355e 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -1,7 +1,7 @@ import matplotlib matplotlib.use('Agg') -from model import Generator +from model import Generator, PrimedGenerator import torch from torch.autograd import Variable @@ -152,7 +152,10 @@ class GeneratorPlugin(Plugin): self.primer = primer def register(self, trainer): - self.generate = Generator(trainer.model.model, trainer.cuda) + if self.primer == "": + self.generate = Generator(trainer.model.model, trainer.cuda) + else: + self.generate = PrimedGenerator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): samples = self.generate(self.n_samples, self.sample_length, self.primer) \ -- cgit v1.2.3-70-g09d2