summaryrefslogtreecommitdiff
path: root/trainer/__init__.py
diff options
context:
space:
mode:
authorjules <jules@asdf.us>2018-03-20 23:35:18 +0100
committerjules <jules@asdf.us>2018-03-20 23:35:18 +0100
commitea6e6ee1040fa85f743ab50b699fbeb04d9a0522 (patch)
treee056f13c3ef89c5b6b8713a7f80c837b333129af /trainer/__init__.py
parent4167442627b1414ff8fdc86528812b46168c656b (diff)
scripts
Diffstat (limited to 'trainer/__init__.py')
-rw-r--r--trainer/__init__.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/trainer/__init__.py b/trainer/__init__.py
index 7e2ea18..1f39506 100644
--- a/trainer/__init__.py
+++ b/trainer/__init__.py
@@ -56,6 +56,15 @@ class Trainer(object):
self.train()
self.call_plugins('epoch', self.epochs)
+ def generate(self, epochs=1):
+ for q in self.plugin_queues.values():
+ heapq.heapify(q)
+
+ for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
+ # self.train()
+ self.call_plugins('update', self.iterations, self.model)
+ self.call_plugins('epoch', self.epochs)
+
def train(self):
for (self.iterations, data) in \
enumerate(self.dataset, self.iterations + 1):