Diffstat (limited to 'trainer/__init__.py')
-rw-r--r--  trainer/__init__.py  118
1 file changed, 118 insertions, 0 deletions
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..7e2ea18
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,118 @@
+import heapq
+
+import torch
+from torch.autograd import Variable
+
+
+# Based on the torch.utils.trainer.Trainer code.
+# Allows multiple inputs to the model; not all of them need to be Tensors.
+class Trainer(object):
+
+    def __init__(self, model, criterion, optimizer, dataset, cuda=False):
+        self.model = model
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.dataset = dataset
+        self.cuda = cuda
+        self.iterations = 0
+        self.epochs = 0
+        self.stats = {}
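+        # One min-heap of scheduled plugin calls per trigger unit;
+        # run() heapifies these before training starts.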
+        self.plugin_queues = {
+            'iteration': [],
+            'epoch': [],
+            'batch': [],
+            'update': [],
+        }
+
+    def register_plugin(self, plugin):
+        plugin.register(self)
+
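+        # trigger_interval may be a single (duration, unit) pair or a
+        # list of them; normalize to a list.  Queue entries are
+        # (next_time, insertion_order, plugin) tuples, so the heap
+        # never has to compare plugin objects directly.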
+        intervals = plugin.trigger_interval
+        if not isinstance(intervals, list):
+            intervals = [intervals]
+        for (duration, unit) in intervals:
+            queue = self.plugin_queues[unit]
+            queue.append((duration, len(queue), plugin))
+
+    def call_plugins(self, queue_name, time, *args):
+        args = (time,) + args
+        queue = self.plugin_queues[queue_name]
+        if len(queue) == 0:
+            return
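+        # Fire every plugin whose scheduled time has arrived, then
+        # reschedule it at its next trigger time.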
+        while queue[0][0] <= time:
+            plugin = queue[0][2]
+            getattr(plugin, queue_name)(*args)
+            intervals = plugin.trigger_interval
+            if not isinstance(intervals, list):
+                intervals = [intervals]
+            for trigger in intervals:
+                if trigger[1] == queue_name:
+                    interval = trigger[0]
+                    new_item = (time + interval, queue[0][1], plugin)
+                    heapq.heappushpop(queue, new_item)
+
+    def run(self, epochs=1):
+        for q in self.plugin_queues.values():
+            heapq.heapify(q)
+
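+        # Epoch numbering continues across successive run() calls.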
+        for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
+            self.train()
+            self.call_plugins('epoch', self.epochs)
+
+    def train(self):
+        for (self.iterations, data) in \
+                enumerate(self.dataset, self.iterations + 1):
+            batch_inputs = data[:-1]
+            batch_target = data[-1]
+            self.call_plugins(
+                'batch', self.iterations, batch_inputs, batch_target
+            )
+
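+            # Only tensors are wrapped in Variables (and moved to the
+            # GPU); anything else passes through untouched, which is
+            # what lets the model take non-Tensor inputs.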
+            def wrap(input):
+                if torch.is_tensor(input):
+                    input = Variable(input)
+                    if self.cuda:
+                        input = input.cuda()
+                return input
+            batch_inputs = list(map(wrap, batch_inputs))
+
+            batch_target = Variable(batch_target)
+            if self.cuda:
+                batch_target = batch_target.cuda()
+
+            plugin_data = [None, None]
+
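+            # The closure recomputes forward and backward so optimizers
+            # such as LBFGS can re-evaluate the loss; it also captures
+            # the first output/loss pair for the iteration plugins.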
+            def closure():
+                batch_output = self.model(*batch_inputs)
+
+                loss = self.criterion(batch_output, batch_target)
+                loss.backward()
+
+                if plugin_data[0] is None:
+                    plugin_data[0] = batch_output.data
+                    plugin_data[1] = loss.data
+
+                return loss
+
+            self.optimizer.zero_grad()
+            self.optimizer.step(closure)
+            self.call_plugins(
+                'iteration', self.iterations, batch_inputs, batch_target,
+                *plugin_data
+            )
+            self.call_plugins('update', self.iterations, self.model)
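
For reference, here is a minimal sketch of how this Trainer and its plugin
contract might be driven. The PrintLoss plugin, the linear model, and the
synthetic dataset are illustrative stand-ins, not part of this commit; the
only requirements the code above imposes are that a plugin expose
register(trainer), a trigger_interval of (duration, unit) pairs, and a
method named after each unit, and that the dataset yield tuples whose last
element is the target.

import torch
import torch.nn as nn
from torch.optim import SGD

from trainer import Trainer


class PrintLoss(object):
    # Illustrative plugin, not part of the commit above.
    # Fire once per iteration; the unit name must match a plugin queue.
    trigger_interval = [(1, 'iteration')]

    def register(self, trainer):
        self.trainer = trainer

    def iteration(self, time, batch_inputs, batch_target, batch_output, loss):
        # loss is the raw loss data captured by the closure above.
        print('iteration {}: loss {}'.format(time, loss))


model = nn.Linear(4, 2)
dataset = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(10)]
trainer = Trainer(model, nn.MSELoss(), SGD(model.parameters(), lr=0.1),
                  dataset)
trainer.register_plugin(PrintLoss())
trainer.run(epochs=2)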