summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--train.py22
1 files changed, 12 insertions, 10 deletions
diff --git a/train.py b/train.py
index 12d62ed..8bab87b 100644
--- a/train.py
+++ b/train.py
@@ -58,17 +58,17 @@ tag_params = [
'batch_size', 'dataset', 'val_frac', 'test_frac'
]
-def make_tag(params):
- def to_string(value):
- if isinstance(value, bool):
- return 'T' if value else 'F'
- elif isinstance(value, list):
- return ','.join(map(to_string, value))
- else:
- return str(value)
+def param_to_string(value):
+ if isinstance(value, bool):
+ return 'T' if value else 'F'
+ elif isinstance(value, list):
+ return ','.join(map(param_to_string, value))
+ else:
+ return str(value)
+def make_tag(params):
return '-'.join(
- key + ':' + to_string(params[key])
+ key + ':' + param_to_string(params[key])
for key in tag_params
if key not in default_params or params[key] != default_params[key]
)
@@ -148,7 +148,9 @@ def init_comet(params, trainer):
from comet_ml import Experiment
from trainer.plugins import CometPlugin
experiment = Experiment(api_key=params['comet_key'], log_code=False)
- hyperparams = {name: params[name] for name in tag_params}
+ hyperparams = {
+ name: param_to_string(params[name]) for name in tag_params
+ }
experiment.log_multiple_params(hyperparams)
trainer.register_plugin(CometPlugin(
experiment, [