summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorPiotr Kozakowski <kozak000@gmail.com>2017-11-03 18:40:07 +0100
committerPiotr Kozakowski <kozak000@gmail.com>2017-11-03 18:40:07 +0100
commit1f4136ee723a457e4f4f8bcc1e5cc328881b00d3 (patch)
tree0d0131b82fa19076b1b6d9d385a74afca6594d94 /train.py
parentbfbadd23cd75a9ab3f837ec890f4c1f78388d85f (diff)
Add option to run without CUDA
Diffstat (limited to 'train.py')
-rw-r--r--train.py16
1 file changed, 12 insertions, 4 deletions
diff --git a/train.py b/train.py
index c47cd4b..934af91 100644
--- a/train.py
+++ b/train.py
@@ -42,7 +42,8 @@ default_params = {
'sample_rate': 16000,
'n_samples': 1,
'sample_length': 80000,
- 'loss_smoothing': 0.99
+ 'loss_smoothing': 0.99,
+ 'cuda': True
}
tag_params = [
@@ -151,8 +152,11 @@ def main(exp, frame_sizes, dataset, **params):
dim=params['dim'],
learn_h0=params['learn_h0'],
q_levels=params['q_levels']
- ).cuda()
- predictor = Predictor(model).cuda()
+ )
+ predictor = Predictor(model)
+ if params['cuda']:
+ model = model.cuda()
+ predictor = predictor.cuda()
optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters()))
@@ -163,7 +167,7 @@ def main(exp, frame_sizes, dataset, **params):
trainer = Trainer(
predictor, sequence_nll_loss_bits, optimizer,
data_loader(0, val_split, eval=False),
- cuda=True
+ cuda=params['cuda']
)
checkpoints_path = os.path.join(results_path, 'checkpoints')
@@ -310,6 +314,10 @@ if __name__ == '__main__':
help='smoothing parameter of the exponential moving average over \
training loss, used in the log and in the loss plot'
)
+ parser.add_argument(
+ '--cuda', type=parse_bool,
+ help='whether to use CUDA'
+ )
parser.set_defaults(**default_params)