From 2d7ad8119c5962a320e6ccc39407e417d088c525 Mon Sep 17 00:00:00 2001 From: Piotr Kozakowski Date: Fri, 3 Nov 2017 19:50:11 +0100 Subject: Add CometML integration --- README.md | 2 ++ train.py | 30 +++++++++++++++++++++++++++++- trainer/plugins.py | 20 ++++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5604d32..34e8530 100644 --- a/README.md +++ b/README.md @@ -30,3 +30,5 @@ python train.py --exp TEST --frame_sizes 16 4 --n_rnn 2 --dataset piano ``` The results - training log, loss plots, model checkpoints and generated samples will be saved in `results/`. + +We also have an option to monitor the metrics using [CometML](https://www.comet.ml/). To use it, just pass your API key as `--comet_key` parameter to `train.py`. diff --git a/train.py b/train.py index 934af91..12d62ed 100644 --- a/train.py +++ b/train.py @@ -1,3 +1,9 @@ +# CometML needs to be imported first. +try: + import comet_ml +except ImportError: + pass + from model import SampleRNN, Predictor from optim import gradient_clipping from nn import sequence_nll_loss_bits @@ -43,7 +49,8 @@ default_params = { 'n_samples': 1, 'sample_length': 80000, 'loss_smoothing': 0.99, - 'cuda': True + 'cuda': True, + 'comet_key': None } tag_params = [ @@ -136,6 +143,21 @@ def make_data_loader(overlap_len, params): ) return data_loader +def init_comet(params, trainer): + if params['comet_key'] is not None: + from comet_ml import Experiment + from trainer.plugins import CometPlugin + experiment = Experiment(api_key=params['comet_key'], log_code=False) + hyperparams = {name: params[name] for name in tag_params} + experiment.log_multiple_params(hyperparams) + trainer.register_plugin(CometPlugin( + experiment, [ + ('training_loss', 'epoch_mean'), + 'validation_loss', + 'test_loss' + ] + )) + def main(exp, frame_sizes, dataset, **params): params = dict( default_params, @@ -226,6 +248,9 @@ def main(exp, frame_sizes, dataset, **params): } } )) + + init_comet(params, trainer) + trainer.run(params['epoch_limit']) @@ -318,6 +343,9 @@ if __name__ == '__main__': '--cuda', type=parse_bool, help='whether to use CUDA' ) + parser.add_argument( + '--comet_key', help='comet.ml API key' + ) parser.set_defaults(**default_params) diff --git a/trainer/plugins.py b/trainer/plugins.py index 552804a..f8c299b 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -265,3 +265,23 @@ class StatsPlugin(Plugin): @classmethod def _fields_to_pairs(cls, fields): return list(map(cls._field_to_pair, fields)) + + +class CometPlugin(Plugin): + + def __init__(self, experiment, fields): + super().__init__([(1, 'epoch')]) + + self.experiment = experiment + self.fields = [ + field if type(field) is tuple else (field, 'last') + for field in fields + ] + + def register(self, trainer): + self.trainer = trainer + + def epoch(self, epoch_index): + for (field, stat) in self.fields: + self.experiment.log_metric(field, self.trainer.stats[field][stat]) + self.experiment.log_epoch_end(epoch_index) -- cgit v1.2.3-70-g09d2