summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py30
1 files changed, 29 insertions, 1 deletions
diff --git a/train.py b/train.py
index 934af91..12d62ed 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,9 @@
+# CometML needs to be imported first.
+try:
+ import comet_ml
+except ImportError:
+ pass
+
from model import SampleRNN, Predictor
from optim import gradient_clipping
from nn import sequence_nll_loss_bits
@@ -43,7 +49,8 @@ default_params = {
'n_samples': 1,
'sample_length': 80000,
'loss_smoothing': 0.99,
- 'cuda': True
+ 'cuda': True,
+ 'comet_key': None
}
tag_params = [
@@ -136,6 +143,21 @@ def make_data_loader(overlap_len, params):
)
return data_loader
+def init_comet(params, trainer):
+ if params['comet_key'] is not None:
+ from comet_ml import Experiment
+ from trainer.plugins import CometPlugin
+ experiment = Experiment(api_key=params['comet_key'], log_code=False)
+ hyperparams = {name: params[name] for name in tag_params}
+ experiment.log_multiple_params(hyperparams)
+ trainer.register_plugin(CometPlugin(
+ experiment, [
+ ('training_loss', 'epoch_mean'),
+ 'validation_loss',
+ 'test_loss'
+ ]
+ ))
+
def main(exp, frame_sizes, dataset, **params):
params = dict(
default_params,
@@ -226,6 +248,9 @@ def main(exp, frame_sizes, dataset, **params):
}
}
))
+
+ init_comet(params, trainer)
+
trainer.run(params['epoch_limit'])
@@ -318,6 +343,9 @@ if __name__ == '__main__':
'--cuda', type=parse_bool,
help='whether to use CUDA'
)
+ parser.add_argument(
+ '--comet_key', help='comet.ml API key'
+ )
parser.set_defaults(**default_params)