summaryrefslogtreecommitdiff
path: root/trainer/plugins.py
diff options
context:
space:
mode:
Diffstat (limited to 'trainer/plugins.py')
-rw-r--r--trainer/plugins.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/trainer/plugins.py b/trainer/plugins.py
index 552804a..f8c299b 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -265,3 +265,23 @@ class StatsPlugin(Plugin):
@classmethod
def _fields_to_pairs(cls, fields):
return list(map(cls._field_to_pair, fields))
+
+
+class CometPlugin(Plugin):
+
+ def __init__(self, experiment, fields):
+ super().__init__([(1, 'epoch')])
+
+ self.experiment = experiment
+ self.fields = [
+ field if type(field) is tuple else (field, 'last')
+ for field in fields
+ ]
+
+ def register(self, trainer):
+ self.trainer = trainer
+
+ def epoch(self, epoch_index):
+ for (field, stat) in self.fields:
+ self.experiment.log_metric(field, self.trainer.stats[field][stat])
+ self.experiment.log_epoch_end(epoch_index)