Diffstat (limited to 'trainer')
-rw-r--r--  trainer/__init__.py  100
-rw-r--r--  trainer/plugins.py   267
2 files changed, 367 insertions, 0 deletions
diff --git a/trainer/__init__.py b/trainer/__init__.py
new file mode 100644
index 0000000..7e2ea18
--- /dev/null
+++ b/trainer/__init__.py
@@ -0,0 +1,100 @@
+import torch
+from torch.autograd import Variable
+
+import heapq
+
+
+# Based on the torch.utils.trainer.Trainer code.
+# Allows multiple inputs to the model, not all of which need to be Tensors.
+class Trainer(object):
+
+ def __init__(self, model, criterion, optimizer, dataset, cuda=False):
+ self.model = model
+ self.criterion = criterion
+ self.optimizer = optimizer
+ self.dataset = dataset
+ self.cuda = cuda
+ self.iterations = 0
+ self.epochs = 0
+ self.stats = {}
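+        # Plugins can hook into four event types; each type gets its own
+        # schedule queue.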
+ self.plugin_queues = {
+ 'iteration': [],
+ 'epoch': [],
+ 'batch': [],
+ 'update': [],
+ }
+
+ def register_plugin(self, plugin):
+ plugin.register(self)
+
+ intervals = plugin.trigger_interval
+ if not isinstance(intervals, list):
+ intervals = [intervals]
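+        # Each queue entry is a (next_trigger_time, insertion_order, plugin)
+        # tuple; the insertion order breaks ties between plugins scheduled
+        # for the same time.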
+ for (duration, unit) in intervals:
+ queue = self.plugin_queues[unit]
+ queue.append((duration, len(queue), plugin))
+
+ def call_plugins(self, queue_name, time, *args):
+ args = (time,) + args
+ queue = self.plugin_queues[queue_name]
+ if len(queue) == 0:
+ return
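+        # Fire every plugin whose scheduled time has arrived, then push it
+        # back onto the heap with its next trigger time.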
+ while queue[0][0] <= time:
+ plugin = queue[0][2]
+ getattr(plugin, queue_name)(*args)
+ for trigger in plugin.trigger_interval:
+ if trigger[1] == queue_name:
+ interval = trigger[0]
+ new_item = (time + interval, queue[0][1], plugin)
+ heapq.heappushpop(queue, new_item)
+
+ def run(self, epochs=1):
+ for q in self.plugin_queues.values():
+ heapq.heapify(q)
+
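+        # Epoch numbering continues across successive run() calls.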
+ for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
+ self.train()
+ self.call_plugins('epoch', self.epochs)
+
+ def train(self):
+ for (self.iterations, data) in \
+ enumerate(self.dataset, self.iterations + 1):
+            batch_inputs = data[:-1]
+ batch_target = data[-1]
+ self.call_plugins(
+ 'batch', self.iterations, batch_inputs, batch_target
+ )
+
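+            # Wrap tensors in Variables and optionally move them to the
+            # GPU; non-tensor inputs pass through unchanged.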
+ def wrap(input):
+ if torch.is_tensor(input):
+ input = Variable(input)
+ if self.cuda:
+ input = input.cuda()
+ return input
+ batch_inputs = list(map(wrap, batch_inputs))
+
+ batch_target = Variable(batch_target)
+ if self.cuda:
+ batch_target = batch_target.cuda()
+
+ plugin_data = [None, None]
+
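+            # optimizer.step() receives a closure that re-evaluates the
+            # model and the loss; the output and loss of the first
+            # evaluation are captured for the plugins.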
+ def closure():
+ batch_output = self.model(*batch_inputs)
+
+ loss = self.criterion(batch_output, batch_target)
+ loss.backward()
+
+ if plugin_data[0] is None:
+ plugin_data[0] = batch_output.data
+ plugin_data[1] = loss.data
+
+ return loss
+
+ self.optimizer.zero_grad()
+ self.optimizer.step(closure)
+ self.call_plugins(
+ 'iteration', self.iterations, batch_inputs, batch_target,
+ *plugin_data
+ )
+ self.call_plugins('update', self.iterations, self.model)
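+
+
+# A minimal usage sketch (model, criterion, optimizer, dataset and
+# SomePlugin are placeholders for the real objects):
+#
+#     trainer = Trainer(model, criterion, optimizer, dataset, cuda=True)
+#     trainer.register_plugin(SomePlugin())
+#     trainer.run(epochs=10)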
diff --git a/trainer/plugins.py b/trainer/plugins.py
new file mode 100644
index 0000000..552804a
--- /dev/null
+++ b/trainer/plugins.py
@@ -0,0 +1,267 @@
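+# Select a non-interactive matplotlib backend before pyplot is imported.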
+import matplotlib
+matplotlib.use('Agg')
+
+from model import Generator
+
+import torch
+from torch.autograd import Variable
+from torch.utils.trainer.plugins.plugin import Plugin
+from torch.utils.trainer.plugins.monitor import Monitor
+from torch.utils.trainer.plugins import LossMonitor
+
+from librosa.output import write_wav
+from matplotlib import pyplot
+
+from glob import glob
+import os
+import pickle
+import time
+
+
+class TrainingLossMonitor(LossMonitor):
+
+ stat_name = 'training_loss'
+
+
+class ValidationPlugin(Plugin):
+
+ def __init__(self, val_dataset, test_dataset):
+ super().__init__([(1, 'epoch')])
+ self.val_dataset = val_dataset
+ self.test_dataset = test_dataset
+
+ def register(self, trainer):
+ self.trainer = trainer
+ val_stats = self.trainer.stats.setdefault('validation_loss', {})
+ val_stats['log_epoch_fields'] = ['{last:.4f}']
+ test_stats = self.trainer.stats.setdefault('test_loss', {})
+ test_stats['log_epoch_fields'] = ['{last:.4f}']
+
+ def epoch(self, idx):
+ self.trainer.model.eval()
+
+ val_stats = self.trainer.stats.setdefault('validation_loss', {})
+ val_stats['last'] = self._evaluate(self.val_dataset)
+ test_stats = self.trainer.stats.setdefault('test_loss', {})
+ test_stats['last'] = self._evaluate(self.test_dataset)
+
+ self.trainer.model.train()
+
+ def _evaluate(self, dataset):
+ loss_sum = 0
+ n_examples = 0
+ for data in dataset:
+            batch_inputs = data[:-1]
+ batch_target = data[-1]
+ batch_size = batch_target.size()[0]
+
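+            # volatile=True disables autograd bookkeeping during
+            # evaluation (the pre-0.4 PyTorch equivalent of torch.no_grad()).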
+ def wrap(input):
+ if torch.is_tensor(input):
+ input = Variable(input, volatile=True)
+ if self.trainer.cuda:
+ input = input.cuda()
+ return input
+ batch_inputs = list(map(wrap, batch_inputs))
+
+ batch_target = Variable(batch_target, volatile=True)
+ if self.trainer.cuda:
+ batch_target = batch_target.cuda()
+
+ batch_output = self.trainer.model(*batch_inputs)
+ loss_sum += self.trainer.criterion(batch_output, batch_target) \
+ .data[0] * batch_size
+
+ n_examples += batch_size
+
+ return loss_sum / n_examples
+
+
+class AbsoluteTimeMonitor(Monitor):
+
+ stat_name = 'time'
+
+ def __init__(self, *args, **kwargs):
+ kwargs.setdefault('unit', 's')
+ kwargs.setdefault('precision', 0)
+ kwargs.setdefault('running_average', False)
+ kwargs.setdefault('epoch_average', False)
+        super().__init__(*args, **kwargs)
+ self.start_time = None
+
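+    # Reports seconds elapsed since the first call, so setup time before
+    # the first measurement is not counted.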
+ def _get_value(self, *args):
+ if self.start_time is None:
+ self.start_time = time.time()
+ return time.time() - self.start_time
+
+
+class SaverPlugin(Plugin):
+
+ last_pattern = 'ep{}-it{}'
+ best_pattern = 'best-ep{}-it{}'
+
+ def __init__(self, checkpoints_path, keep_old_checkpoints):
+ super().__init__([(1, 'epoch')])
+ self.checkpoints_path = checkpoints_path
+ self.keep_old_checkpoints = keep_old_checkpoints
+ self._best_val_loss = float('+inf')
+
+ def register(self, trainer):
+ self.trainer = trainer
+
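+    # After each epoch, save the latest weights and keep a separate copy
+    # of the checkpoint with the best validation loss seen so far.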
+ def epoch(self, epoch_index):
+ if not self.keep_old_checkpoints:
+ self._clear(self.last_pattern.format('*', '*'))
+ torch.save(
+ self.trainer.model.state_dict(),
+ os.path.join(
+ self.checkpoints_path,
+ self.last_pattern.format(epoch_index, self.trainer.iterations)
+ )
+ )
+
+ cur_val_loss = self.trainer.stats['validation_loss']['last']
+ if cur_val_loss < self._best_val_loss:
+ self._clear(self.best_pattern.format('*', '*'))
+ torch.save(
+ self.trainer.model.state_dict(),
+ os.path.join(
+ self.checkpoints_path,
+ self.best_pattern.format(
+ epoch_index, self.trainer.iterations
+ )
+ )
+ )
+ self._best_val_loss = cur_val_loss
+
+ def _clear(self, pattern):
+ pattern = os.path.join(self.checkpoints_path, pattern)
+ for file_name in glob(pattern):
+ os.remove(file_name)
+
+
+class GeneratorPlugin(Plugin):
+
+ pattern = 'ep{}-s{}.wav'
+
+ def __init__(self, samples_path, n_samples, sample_length, sample_rate):
+ super().__init__([(1, 'epoch')])
+ self.samples_path = samples_path
+ self.n_samples = n_samples
+ self.sample_length = sample_length
+ self.sample_rate = sample_rate
+
+ def register(self, trainer):
+ self.generate = Generator(trainer.model.model, trainer.cuda)
+
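+    # After each epoch, sample audio from the model and write one .wav
+    # file per sample.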
+ def epoch(self, epoch_index):
+ samples = self.generate(self.n_samples, self.sample_length) \
+ .cpu().float().numpy()
+ for i in range(self.n_samples):
+ write_wav(
+ os.path.join(
+ self.samples_path, self.pattern.format(epoch_index, i + 1)
+ ),
+ samples[i, :], sr=self.sample_rate, norm=True
+ )
+
+
+class StatsPlugin(Plugin):
+
+ data_file_name = 'stats.pkl'
+ plot_pattern = '{}.svg'
+
+ def __init__(self, results_path, iteration_fields, epoch_fields, plots):
+ super().__init__([(1, 'iteration'), (1, 'epoch')])
+ self.results_path = results_path
+
+ self.iteration_fields = self._fields_to_pairs(iteration_fields)
+ self.epoch_fields = self._fields_to_pairs(epoch_fields)
+ self.plots = plots
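+        # Each entry of `plots` maps a plot name to a dict with an 'x'
+        # field, a 'y' field or a 'ys' list, optional per-series 'formats',
+        # and an optional 'log_y' flag.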
+ self.data = {
+ 'iterations': {
+ field: []
+ for field in self.iteration_fields + [('iteration', 'last')]
+ },
+ 'epochs': {
+ field: []
+ for field in self.epoch_fields + [('iteration', 'last')]
+ }
+ }
+
+ def register(self, trainer):
+ self.trainer = trainer
+
+ def iteration(self, *args):
+ for (field, stat) in self.iteration_fields:
+ self.data['iterations'][field, stat].append(
+ self.trainer.stats[field][stat]
+ )
+
+ self.data['iterations']['iteration', 'last'].append(
+ self.trainer.iterations
+ )
+
+ def epoch(self, epoch_index):
+ for (field, stat) in self.epoch_fields:
+ self.data['epochs'][field, stat].append(
+ self.trainer.stats[field][stat]
+ )
+
+ self.data['epochs']['iteration', 'last'].append(
+ self.trainer.iterations
+ )
+
+ data_file_path = os.path.join(self.results_path, self.data_file_name)
+ with open(data_file_path, 'wb') as f:
+ pickle.dump(self.data, f)
+
+ for (name, info) in self.plots.items():
+ x_field = self._field_to_pair(info['x'])
+
+ try:
+ y_fields = info['ys']
+ except KeyError:
+ y_fields = [info['y']]
+
+ labels = list(map(
+ lambda x: ' '.join(x) if type(x) is tuple else x,
+ y_fields
+ ))
+ y_fields = self._fields_to_pairs(y_fields)
+
+ try:
+ formats = info['formats']
+ except KeyError:
+ formats = [''] * len(y_fields)
+
+ pyplot.gcf().clear()
+
+            for (y_field, fmt, label) in zip(y_fields, formats, labels):
+                if y_field in self.iteration_fields:
+                    part_name = 'iterations'
+                else:
+                    part_name = 'epochs'
+
+                xs = self.data[part_name][x_field]
+                ys = self.data[part_name][y_field]
+
+                pyplot.plot(xs, ys, fmt, label=label)
+
+ if 'log_y' in info and info['log_y']:
+ pyplot.yscale('log')
+
+ pyplot.legend()
+ pyplot.savefig(
+ os.path.join(self.results_path, self.plot_pattern.format(name))
+ )
+
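+    # Fields may be given either as a bare stat name or as a (name, stat)
+    # pair; bare names default to the 'last' stat.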
+ @staticmethod
+ def _field_to_pair(field):
+ if type(field) is tuple:
+ return field
+ else:
+ return (field, 'last')
+
+ @classmethod
+ def _fields_to_pairs(cls, fields):
+ return list(map(cls._field_to_pair, fields))