diff options
| -rw-r--r-- | dataset.py | 2 | ||||
| -rw-r--r-- | train.py | 16 |
2 files changed, 13 insertions, 5 deletions
@@ -58,7 +58,7 @@ class DataLoader(DataLoaderBase): to_index = seq_begin + self.seq_len sequences = batch[:, from_index : to_index] input_sequences = sequences[:, : -1] - target_sequences = sequences[:, self.overlap_len :] + target_sequences = sequences[:, self.overlap_len :].contiguous() yield (input_sequences, reset, target_sequences) @@ -42,7 +42,8 @@ default_params = { 'sample_rate': 16000, 'n_samples': 1, 'sample_length': 80000, - 'loss_smoothing': 0.99 + 'loss_smoothing': 0.99, + 'cuda': True } tag_params = [ @@ -151,8 +152,11 @@ def main(exp, frame_sizes, dataset, **params): dim=params['dim'], learn_h0=params['learn_h0'], q_levels=params['q_levels'] - ).cuda() - predictor = Predictor(model).cuda() + ) + predictor = Predictor(model) + if params['cuda']: + model = model.cuda() + predictor = predictor.cuda() optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters())) @@ -163,7 +167,7 @@ def main(exp, frame_sizes, dataset, **params): trainer = Trainer( predictor, sequence_nll_loss_bits, optimizer, data_loader(0, val_split, eval=False), - cuda=True + cuda=params['cuda'] ) checkpoints_path = os.path.join(results_path, 'checkpoints') @@ -310,6 +314,10 @@ if __name__ == '__main__': help='smoothing parameter of the exponential moving average over \ training loss, used in the log and in the loss plot' ) + parser.add_argument( + '--cuda', type=parse_bool, + help='whether to use CUDA' + ) parser.set_defaults(**default_params) |
