summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPiotr Kozakowski <kozak000@gmail.com>2017-11-03 19:50:11 +0100
committerPiotr Kozakowski <kozak000@gmail.com>2017-11-03 20:28:01 +0100
commit2d7ad8119c5962a320e6ccc39407e417d088c525 (patch)
tree1a253f9fe574eb88a21b3b6c3e78ca545e1f1430
parent7912b6e2f8610e22d9f7e6ceb3e16ed0828b46bf (diff)
Add CometML integration
-rw-r--r--README.md2
-rw-r--r--train.py30
-rw-r--r--trainer/plugins.py20
3 files changed, 51 insertions, 1 deletions
diff --git a/README.md b/README.md
index 5604d32..34e8530 100644
--- a/README.md
+++ b/README.md
@@ -30,3 +30,5 @@ python train.py --exp TEST --frame_sizes 16 4 --n_rnn 2 --dataset piano
```
The results - training log, loss plots, model checkpoints and generated samples will be saved in `results/`.
+
+We also have an option to monitor the metrics using [CometML](https://www.comet.ml/). To use it, just pass your API key as `--comet_key` parameter to `train.py`.
diff --git a/train.py b/train.py
index 934af91..12d62ed 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,9 @@
+# CometML needs to be imported first.
+try:
+ import comet_ml
+except ImportError:
+ pass
+
from model import SampleRNN, Predictor
from optim import gradient_clipping
from nn import sequence_nll_loss_bits
@@ -43,7 +49,8 @@ default_params = {
'n_samples': 1,
'sample_length': 80000,
'loss_smoothing': 0.99,
- 'cuda': True
+ 'cuda': True,
+ 'comet_key': None
}
tag_params = [
@@ -136,6 +143,21 @@ def make_data_loader(overlap_len, params):
)
return data_loader
+def init_comet(params, trainer):
+ if params['comet_key'] is not None:
+ from comet_ml import Experiment
+ from trainer.plugins import CometPlugin
+ experiment = Experiment(api_key=params['comet_key'], log_code=False)
+ hyperparams = {name: params[name] for name in tag_params}
+ experiment.log_multiple_params(hyperparams)
+ trainer.register_plugin(CometPlugin(
+ experiment, [
+ ('training_loss', 'epoch_mean'),
+ 'validation_loss',
+ 'test_loss'
+ ]
+ ))
+
def main(exp, frame_sizes, dataset, **params):
params = dict(
default_params,
@@ -226,6 +248,9 @@ def main(exp, frame_sizes, dataset, **params):
}
}
))
+
+ init_comet(params, trainer)
+
trainer.run(params['epoch_limit'])
@@ -318,6 +343,9 @@ if __name__ == '__main__':
'--cuda', type=parse_bool,
help='whether to use CUDA'
)
+ parser.add_argument(
+ '--comet_key', help='comet.ml API key'
+ )
parser.set_defaults(**default_params)
diff --git a/trainer/plugins.py b/trainer/plugins.py
index 552804a..f8c299b 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -265,3 +265,23 @@ class StatsPlugin(Plugin):
@classmethod
def _fields_to_pairs(cls, fields):
return list(map(cls._field_to_pair, fields))
+
+
+class CometPlugin(Plugin):
+
+ def __init__(self, experiment, fields):
+ super().__init__([(1, 'epoch')])
+
+ self.experiment = experiment
+ self.fields = [
+ field if type(field) is tuple else (field, 'last')
+ for field in fields
+ ]
+
+ def register(self, trainer):
+ self.trainer = trainer
+
+ def epoch(self, epoch_index):
+ for (field, stat) in self.fields:
+ self.experiment.log_metric(field, self.trainer.stats[field][stat])
+ self.experiment.log_epoch_end(epoch_index)