summaryrefslogtreecommitdiff
path: root/trainer
diff options
context:
space:
mode:
authorPiotr Kozakowski <kozak000@gmail.com>2017-11-03 19:50:11 +0100
committerPiotr Kozakowski <kozak000@gmail.com>2017-11-03 20:28:01 +0100
commit2d7ad8119c5962a320e6ccc39407e417d088c525 (patch)
tree1a253f9fe574eb88a21b3b6c3e78ca545e1f1430 /trainer
parent7912b6e2f8610e22d9f7e6ceb3e16ed0828b46bf (diff)
Add CometML integration
Diffstat (limited to 'trainer')
-rw-r--r--trainer/plugins.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py
index 552804a..f8c299b 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -265,3 +265,23 @@ class StatsPlugin(Plugin):
@classmethod
def _fields_to_pairs(cls, fields):
return list(map(cls._field_to_pair, fields))
+
+
+class CometPlugin(Plugin):
+
+ def __init__(self, experiment, fields):
+ super().__init__([(1, 'epoch')])
+
+ self.experiment = experiment
+ self.fields = [
+ field if type(field) is tuple else (field, 'last')
+ for field in fields
+ ]
+
+ def register(self, trainer):
+ self.trainer = trainer
+
+ def epoch(self, epoch_index):
+ for (field, stat) in self.fields:
+ self.experiment.log_metric(field, self.trainer.stats[field][stat])
+ self.experiment.log_epoch_end(epoch_index)