diff options
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 16 |
1 files changed, 12 insertions, 4 deletions
@@ -42,7 +42,8 @@ default_params = { 'sample_rate': 16000, 'n_samples': 1, 'sample_length': 80000, - 'loss_smoothing': 0.99 + 'loss_smoothing': 0.99, + 'cuda': True } tag_params = [ @@ -151,8 +152,11 @@ def main(exp, frame_sizes, dataset, **params): dim=params['dim'], learn_h0=params['learn_h0'], q_levels=params['q_levels'] - ).cuda() - predictor = Predictor(model).cuda() + ) + predictor = Predictor(model) + if params['cuda']: + model = model.cuda() + predictor = predictor.cuda() optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters())) @@ -163,7 +167,7 @@ def main(exp, frame_sizes, dataset, **params): trainer = Trainer( predictor, sequence_nll_loss_bits, optimizer, data_loader(0, val_split, eval=False), - cuda=True + cuda=params['cuda'] ) checkpoints_path = os.path.join(results_path, 'checkpoints') @@ -310,6 +314,10 @@ if __name__ == '__main__': help='smoothing parameter of the exponential moving average over \ training loss, used in the log and in the loss plot' ) + parser.add_argument( + '--cuda', type=parse_bool, + help='whether to use CUDA' + ) parser.set_defaults(**default_params) |
