diff options
| author | Piotr Kozakowski <kozak000@gmail.com> | 2017-11-03 19:50:11 +0100 |
|---|---|---|
| committer | Piotr Kozakowski <kozak000@gmail.com> | 2017-11-03 20:28:01 +0100 |
| commit | 2d7ad8119c5962a320e6ccc39407e417d088c525 (patch) | |
| tree | 1a253f9fe574eb88a21b3b6c3e78ca545e1f1430 /trainer | |
| parent | 7912b6e2f8610e22d9f7e6ceb3e16ed0828b46bf (diff) | |
Add CometML integration
Diffstat (limited to 'trainer')
| -rw-r--r-- | trainer/plugins.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py index 552804a..f8c299b 100644 --- a/trainer/plugins.py +++ b/trainer/plugins.py @@ -265,3 +265,23 @@ class StatsPlugin(Plugin): @classmethod def _fields_to_pairs(cls, fields): return list(map(cls._field_to_pair, fields)) + + +class CometPlugin(Plugin): + + def __init__(self, experiment, fields): + super().__init__([(1, 'epoch')]) + + self.experiment = experiment + self.fields = [ + field if type(field) is tuple else (field, 'last') + for field in fields + ] + + def register(self, trainer): + self.trainer = trainer + + def epoch(self, epoch_index): + for (field, stat) in self.fields: + self.experiment.log_metric(field, self.trainer.stats[field][stat]) + self.experiment.log_epoch_end(epoch_index) |
