import torch
from torch.autograd import Variable
import heapq
# Based on torch.utils.trainer.Trainer code.
# Allows multiple inputs to the model, not all need to be Tensors.
class Trainer(object):
    """Training loop with heap-scheduled plugins.

    Based on torch.utils.trainer.Trainer, but allows the model to take
    multiple inputs (``data[:-1]``), not all of which need to be Tensors.

    Plugins register for one or more ``(interval, unit)`` triggers, where
    ``unit`` is one of ``'iteration'``, ``'epoch'``, ``'batch'``,
    ``'update'``; each unit has its own min-heap of pending callbacks keyed
    by the next time they should fire.
    """

    def __init__(self, model, criterion, optimizer, dataset, cuda=False):
        """Store the training components.

        Args:
            model: callable taking ``*batch_inputs`` and returning the output.
            criterion: loss callable taking ``(output, target)``.
            optimizer: optimizer whose ``step`` accepts a closure.
            dataset: iterable yielding tuples ``(*inputs, target)``.
            cuda: if True, move tensor inputs/targets to the GPU.
        """
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.dataset = dataset
        self.cuda = cuda
        self.iterations = 0
        self.epochs = 0
        self.stats = {}
        # One min-heap of (next_fire_time, insertion_order, plugin) per event
        # unit. Heaps are only heapified lazily in run().
        self.plugin_queues = {
            'iteration': [],
            'epoch': [],
            'batch': [],
            'update': [],
        }

    def register_plugin(self, plugin):
        """Register *plugin* and enqueue its triggers.

        ``plugin.trigger_interval`` may be a single ``(interval, unit)``
        tuple or a list of them.
        """
        plugin.register(self)
        intervals = plugin.trigger_interval
        if not isinstance(intervals, list):
            intervals = [intervals]
        for duration, unit in intervals:
            queue = self.plugin_queues[unit]
            # len(queue) is a tie-breaker so same-time entries compare by
            # registration order instead of comparing plugin objects (which
            # are not orderable).
            queue.append((duration, len(queue), plugin))

    def call_plugins(self, queue_name, time, *args):
        """Fire every plugin in *queue_name* whose scheduled time is due.

        Each fired plugin is rescheduled ``interval`` units in the future
        according to its trigger for this queue.
        """
        args = (time,) + args
        queue = self.plugin_queues[queue_name]
        if len(queue) == 0:
            return
        while queue[0][0] <= time:
            plugin = queue[0][2]
            getattr(plugin, queue_name)(*args)
            # Bug fix: normalize a bare (interval, unit) tuple exactly like
            # register_plugin does. Iterating a bare tuple here used to walk
            # its elements (an int and a str) and crash on trigger[1].
            intervals = plugin.trigger_interval
            if not isinstance(intervals, list):
                intervals = [intervals]
            interval = None
            for duration, unit in intervals:
                if unit == queue_name:
                    interval = duration
            if interval is None:
                # Defensive: no trigger matches this queue (the original
                # code would hit a NameError on an undefined new_item).
                # Drop the stale entry instead of rescheduling.
                heapq.heappop(queue)
                if not queue:
                    return
                continue
            new_item = (time + interval, queue[0][1], plugin)
            # Push the rescheduled entry and pop the fired one in one O(log n) op.
            heapq.heappushpop(queue, new_item)

    def run(self, epochs=1):
        """Train for *epochs* additional epochs, firing 'epoch' plugins."""
        for q in self.plugin_queues.values():
            heapq.heapify(q)
        # self.epochs persists across calls, so run() resumes the epoch count.
        for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
            self.train()
            self.call_plugins('epoch', self.epochs)

    def train(self):
        """Run one epoch over ``self.dataset``.

        Each batch is a tuple whose last element is the target and whose
        remaining elements are the model inputs.
        """
        # self.iterations persists across epochs; enumerate resumes from it.
        for self.iterations, data in \
                enumerate(self.dataset, self.iterations + 1):
            batch_inputs = data[:-1]
            batch_target = data[-1]
            self.call_plugins(
                'batch', self.iterations, batch_inputs, batch_target
            )

            def wrap(input):
                # Only tensors are wrapped/moved; non-tensor inputs pass
                # through untouched.
                if torch.is_tensor(input):
                    input = Variable(input)
                    if self.cuda:
                        input = input.cuda()
                return input
            batch_inputs = list(map(wrap, batch_inputs))
            batch_target = Variable(batch_target)
            if self.cuda:
                batch_target = batch_target.cuda()

            # [output.data, loss.data] captured on the first closure call so
            # 'iteration' plugins can inspect them.
            plugin_data = [None, None]

            def closure():
                batch_output = self.model(*batch_inputs)
                loss = self.criterion(batch_output, batch_target)
                loss.backward()
                if plugin_data[0] is None:
                    plugin_data[0] = batch_output.data
                    plugin_data[1] = loss.data
                return loss

            self.optimizer.zero_grad()
            self.optimizer.step(closure)

            self.call_plugins(
                'iteration', self.iterations, batch_inputs, batch_target,
                *plugin_data
            )
            self.call_plugins('update', self.iterations, self.model)