diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 20:02:19 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 20:02:19 +0200 |
| commit | 7de430fe12bc6ebdab054884dddda20463431181 (patch) | |
| tree | 50a63f7af38788bde1e334a02d347fa75ff111ea /trainer | |
| parent | 88eecd62b75cc032752aa10121d376cc7bca418b (diff) | |
generator function..
Diffstat (limited to 'trainer')
| -rw-r--r-- | trainer/plugins.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py index dc3b24a..562355e 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -1,7 +1,7 @@ import matplotlib matplotlib.use('Agg') -from model import Generator +from model import Generator, PrimedGenerator import torch from torch.autograd import Variable @@ -152,7 +152,10 @@ class GeneratorPlugin(Plugin): self.primer = primer def register(self, trainer): - self.generate = Generator(trainer.model.model, trainer.cuda) + if self.primer == "": + self.generate = Generator(trainer.model.model, trainer.cuda) + else: + self.generate = PrimedGenerator(trainer.model.model, trainer.cuda) def epoch(self, epoch_index): samples = self.generate(self.n_samples, self.sample_length, self.primer) \ |
