Diffstat (limited to 'trainer/__init__.py')
| -rw-r--r-- | trainer/__init__.py | 100 |
1 file changed, 100 insertions, 0 deletions
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..7e2ea18
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,100 @@
+import torch
+from torch.autograd import Variable
+
+import heapq
+
+
+# Based on torch.utils.trainer.Trainer code.
+# Allows multiple inputs to the model, not all need to be Tensors.
+class Trainer(object):
+
+    def __init__(self, model, criterion, optimizer, dataset, cuda=False):
+        self.model = model
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.dataset = dataset
+        self.cuda = cuda
+        self.iterations = 0
+        self.epochs = 0
+        self.stats = {}
+        self.plugin_queues = {
+            'iteration': [],
+            'epoch': [],
+            'batch': [],
+            'update': [],
+        }
+
+    def register_plugin(self, plugin):
+        plugin.register(self)
+
+        intervals = plugin.trigger_interval
+        if not isinstance(intervals, list):
+            intervals = [intervals]
+        for (duration, unit) in intervals:
+            queue = self.plugin_queues[unit]
+            queue.append((duration, len(queue), plugin))
+
+    def call_plugins(self, queue_name, time, *args):
+        args = (time,) + args
+        queue = self.plugin_queues[queue_name]
+        if len(queue) == 0:
+            return
+        while queue[0][0] <= time:
+            plugin = queue[0][2]
+            getattr(plugin, queue_name)(*args)
+            for trigger in plugin.trigger_interval:
+                if trigger[1] == queue_name:
+                    interval = trigger[0]
+            new_item = (time + interval, queue[0][1], plugin)
+            heapq.heappushpop(queue, new_item)
+
+    def run(self, epochs=1):
+        for q in self.plugin_queues.values():
+            heapq.heapify(q)
+
+        for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
+            self.train()
+            self.call_plugins('epoch', self.epochs)
+
+    def train(self):
+        for (self.iterations, data) in \
+                enumerate(self.dataset, self.iterations + 1):
+            batch_inputs = data[: -1]
+            batch_target = data[-1]
+            self.call_plugins(
+                'batch', self.iterations, batch_inputs, batch_target
+            )
+
+            def wrap(input):
+                if torch.is_tensor(input):
+                    input = Variable(input)
+                    if self.cuda:
+                        input = input.cuda()
+                return input
+            batch_inputs = list(map(wrap, batch_inputs))
+
+            batch_target = Variable(batch_target)
+            if self.cuda:
+                batch_target = batch_target.cuda()
+
+            plugin_data = [None, None]
+
+            def closure():
+                batch_output = self.model(*batch_inputs)
+
+                loss = self.criterion(batch_output, batch_target)
+                loss.backward()
+
+                if plugin_data[0] is None:
+                    plugin_data[0] = batch_output.data
+                    plugin_data[1] = loss.data
+
+                return loss
+
+            self.optimizer.zero_grad()
+            self.optimizer.step(closure)
+            self.call_plugins(
+                'iteration', self.iterations, batch_inputs, batch_target,
+                *plugin_data
+            )
+            self.call_plugins('update', self.iterations, self.model)
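The diff does not include a sample plugin, so the following is a minimal sketch of the interface that register_plugin() and call_plugins() imply: a plugin exposes register(trainer), a trigger_interval list of (interval, unit) pairs, and one method per queue name it subscribes to ('iteration', 'epoch', 'batch', 'update'). The class name and attributes below are hypothetical, not part of the commit.

class LossLoggerPlugin(object):
    # Fire on every iteration and at the end of every epoch;
    # register_plugin() pushes one heap entry per (interval, unit) pair.
    trigger_interval = [(1, 'iteration'), (1, 'epoch')]

    def register(self, trainer):
        # Keep a reference so callbacks can reach trainer state if needed.
        self.trainer = trainer
        self.running_loss = 0.0
        self.batches = 0

    def iteration(self, time, batch_inputs, batch_target, batch_output, loss):
        # call_plugins() prepends the current iteration count as `time`;
        # batch_output and loss arrive via plugin_data captured in the
        # closure, as raw .data tensors. float(loss.sum()) works on both
        # pre-0.4 tensors and 0-dim tensors.
        self.running_loss += float(loss.sum())
        self.batches += 1

    def epoch(self, time):
        # Here `time` is the epoch counter passed by run().
        mean = self.running_loss / max(self.batches, 1)
        print('epoch {}: mean loss {:.4f}'.format(time, mean))
        self.running_loss, self.batches = 0.0, 0

Scheduling follows from call_plugins(): a heap entry (next_time, insertion_order, plugin) fires while its next_time is <= the current counter, then heappushpop() reschedules it at next_time + interval, so a (1, 'iteration') trigger fires once per batch.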

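And a hypothetical end-to-end driver, assuming the Trainer from the diff and the LossLoggerPlugin sketch above are in scope. It relies only on what train() requires: a dataset iterable yielding tuples whose last element is the target (data[: -1] / data[-1]), written against the same pre-0.4 Variable-era PyTorch API the commit uses. The model, optimizer, and random data are stand-ins.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(8, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# DataLoader yields (inputs, targets) pairs, so inside train()
# batch_inputs == (inputs,) and batch_target == targets.
dataset = DataLoader(
    TensorDataset(torch.randn(64, 8), torch.randn(64, 1)),
    batch_size=16,
)

trainer = Trainer(model, criterion, optimizer, dataset, cuda=False)
trainer.register_plugin(LossLoggerPlugin())
trainer.run(epochs=2)

Note that train() hands the whole forward/backward pass to optimizer.step(closure): every torch.optim optimizer accepts this closure form, and optimizers such as torch.optim.LBFGS that re-evaluate the loss multiple times per step require it, which is presumably why the Trainer is written this way.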