Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 316 |
1 file changed, 316 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..c47cd4b
--- /dev/null
+++ b/train.py
@@ -0,0 +1,316 @@
+from model import SampleRNN, Predictor
+from optim import gradient_clipping
+from nn import sequence_nll_loss_bits
+from trainer import Trainer
+from trainer.plugins import (
+    TrainingLossMonitor, ValidationPlugin, AbsoluteTimeMonitor, SaverPlugin,
+    GeneratorPlugin, StatsPlugin
+)
+from dataset import FolderDataset, DataLoader
+
+import torch
+from torch.utils.trainer.plugins import Logger
+
+from natsort import natsorted
+
+from functools import reduce
+import os
+import shutil
+import sys
+from glob import glob
+import re
+import argparse
+
+
+default_params = {
+    # model parameters
+    'n_rnn': 1,
+    'dim': 1024,
+    'learn_h0': True,
+    'q_levels': 256,
+    'seq_len': 1024,
+    'batch_size': 128,
+    'val_frac': 0.1,
+    'test_frac': 0.1,
+
+    # training parameters
+    'keep_old_checkpoints': False,
+    'datasets_path': 'datasets',
+    'results_path': 'results',
+    'epoch_limit': 1000,
+    'resume': True,
+    'sample_rate': 16000,
+    'n_samples': 1,
+    'sample_length': 80000,
+    'loss_smoothing': 0.99
+}
+
+tag_params = [
+    'exp', 'frame_sizes', 'n_rnn', 'dim', 'learn_h0', 'q_levels', 'seq_len',
+    'batch_size', 'dataset', 'val_frac', 'test_frac'
+]
+
+def make_tag(params):
+    def to_string(value):
+        if isinstance(value, bool):
+            return 'T' if value else 'F'
+        elif isinstance(value, list):
+            return ','.join(map(to_string, value))
+        else:
+            return str(value)
+
+    return '-'.join(
+        key + ':' + to_string(params[key])
+        for key in tag_params
+        if key not in default_params or params[key] != default_params[key]
+    )
+
+def setup_results_dir(params):
+    def ensure_dir_exists(path):
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+    tag = make_tag(params)
+    results_path = os.path.abspath(params['results_path'])
+    ensure_dir_exists(results_path)
+    results_path = os.path.join(results_path, tag)
+    if not os.path.exists(results_path):
+        os.makedirs(results_path)
+    elif not params['resume']:
+        shutil.rmtree(results_path)
+        os.makedirs(results_path)
+
+    for subdir in ['checkpoints', 'samples']:
+        ensure_dir_exists(os.path.join(results_path, subdir))
+
+    return results_path
+
+def load_last_checkpoint(checkpoints_path):
+    checkpoints_pattern = os.path.join(
+        checkpoints_path, SaverPlugin.last_pattern.format('*', '*')
+    )
+    checkpoint_paths = natsorted(glob(checkpoints_pattern))
+    if len(checkpoint_paths) > 0:
+        checkpoint_path = checkpoint_paths[-1]
+        checkpoint_name = os.path.basename(checkpoint_path)
+        match = re.match(
+            SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'),
+            checkpoint_name
+        )
+        epoch = int(match.group(1))
+        iteration = int(match.group(2))
+        return (torch.load(checkpoint_path), epoch, iteration)
+    else:
+        return None
+
+def tee_stdout(log_path):
+    log_file = open(log_path, 'a', 1)
+    stdout = sys.stdout
+
+    class Tee:
+
+        def write(self, string):
+            log_file.write(string)
+            stdout.write(string)
+
+        def flush(self):
+            log_file.flush()
+            stdout.flush()
+
+    sys.stdout = Tee()
+
+def make_data_loader(overlap_len, params):
+    path = os.path.join(params['datasets_path'], params['dataset'])
+    def data_loader(split_from, split_to, eval):
+        dataset = FolderDataset(
+            path, overlap_len, params['q_levels'], split_from, split_to
+        )
+        return DataLoader(
+            dataset,
+            batch_size=params['batch_size'],
+            seq_len=params['seq_len'],
+            overlap_len=overlap_len,
+            shuffle=(not eval),
+            drop_last=(not eval)
+        )
+    return data_loader
+
+def main(exp, frame_sizes, dataset, **params):
+    params = dict(
+        default_params,
+        exp=exp, frame_sizes=frame_sizes, dataset=dataset,
+        **params
+    )
+
+    results_path = setup_results_dir(params)
+    tee_stdout(os.path.join(results_path, 'log'))
+
+    model = SampleRNN(
+        frame_sizes=params['frame_sizes'],
+        n_rnn=params['n_rnn'],
+        dim=params['dim'],
+        learn_h0=params['learn_h0'],
+        q_levels=params['q_levels']
+    ).cuda()
+    predictor = Predictor(model).cuda()
+
+    optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters()))
+
+    data_loader = make_data_loader(model.lookback, params)
+    test_split = 1 - params['test_frac']
+    val_split = test_split - params['val_frac']
+
+    trainer = Trainer(
+        predictor, sequence_nll_loss_bits, optimizer,
+        data_loader(0, val_split, eval=False),
+        cuda=True
+    )
+
+    checkpoints_path = os.path.join(results_path, 'checkpoints')
+    checkpoint_data = load_last_checkpoint(checkpoints_path)
+    if checkpoint_data is not None:
+        (state_dict, epoch, iteration) = checkpoint_data
+        trainer.epochs = epoch
+        trainer.iterations = iteration
+        predictor.load_state_dict(state_dict)
+
+    trainer.register_plugin(TrainingLossMonitor(
+        smoothing=params['loss_smoothing']
+    ))
+    trainer.register_plugin(ValidationPlugin(
+        data_loader(val_split, test_split, eval=True),
+        data_loader(test_split, 1, eval=True)
+    ))
+    trainer.register_plugin(AbsoluteTimeMonitor())
+    trainer.register_plugin(SaverPlugin(
+        checkpoints_path, params['keep_old_checkpoints']
+    ))
+    trainer.register_plugin(GeneratorPlugin(
+        os.path.join(results_path, 'samples'), params['n_samples'],
+        params['sample_length'], params['sample_rate']
+    ))
+    trainer.register_plugin(
+        Logger([
+            'training_loss',
+            'validation_loss',
+            'test_loss',
+            'time'
+        ])
+    )
+    trainer.register_plugin(StatsPlugin(
+        results_path,
+        iteration_fields=[
+            'training_loss',
+            ('training_loss', 'running_avg'),
+            'time'
+        ],
+        epoch_fields=[
+            'validation_loss',
+            'test_loss',
+            'time'
+        ],
+        plots={
+            'loss': {
+                'x': 'iteration',
+                'ys': [
+                    'training_loss',
+                    ('training_loss', 'running_avg'),
+                    'validation_loss',
+                    'test_loss',
+                ],
+                'log_y': True
+            }
+        }
+    ))
+    trainer.run(params['epoch_limit'])
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        argument_default=argparse.SUPPRESS
+    )
+
+    def parse_bool(arg):
+        arg = arg.lower()
+        if 'true'.startswith(arg):
+            return True
+        elif 'false'.startswith(arg):
+            return False
+        else:
+            raise ValueError()
+
+    parser.add_argument('--exp', required=True, help='experiment name')
+    parser.add_argument(
+        '--frame_sizes', nargs='+', type=int, required=True,
+        help='frame sizes in terms of the number of lower tier frames, \
+              starting from the lowest RNN tier'
+    )
+    parser.add_argument(
+        '--dataset', required=True,
+        help='dataset name - name of a directory in the datasets path \
+              (settable by --datasets_path)'
+    )
+    parser.add_argument(
+        '--n_rnn', type=int, help='number of RNN layers in each tier'
+    )
+    parser.add_argument(
+        '--dim', type=int, help='number of neurons in every RNN and MLP layer'
+    )
+    parser.add_argument(
+        '--learn_h0', type=parse_bool,
+        help='whether to learn the initial states of RNNs'
+    )
+    parser.add_argument(
+        '--q_levels', type=int,
+        help='number of bins in quantization of audio samples'
+    )
+    parser.add_argument(
+        '--seq_len', type=int,
+        help='how many samples to include in each truncated BPTT pass'
+    )
+    parser.add_argument('--batch_size', type=int, help='batch size')
+    parser.add_argument(
+        '--val_frac', type=float,
+        help='fraction of data to go into the validation set'
+    )
+    parser.add_argument(
+        '--test_frac', type=float,
+        help='fraction of data to go into the test set'
+    )
+    parser.add_argument(
+        '--keep_old_checkpoints', type=parse_bool,
+        help='whether to keep checkpoints from past epochs'
+    )
+    parser.add_argument(
+        '--datasets_path', help='path to the directory containing datasets'
+    )
+    parser.add_argument(
+        '--results_path', help='path to the directory to save the results to'
+    )
+    parser.add_argument(
+        '--epoch_limit', type=int, help='how many epochs to run'
+    )
+    parser.add_argument(
+        '--resume', type=parse_bool, default=True,
+        help='whether to resume training from the last checkpoint'
+    )
+    parser.add_argument(
+        '--sample_rate', type=int,
+        help='sample rate of the training data and generated sound'
+    )
+    parser.add_argument(
+        '--n_samples', type=int,
+        help='number of samples to generate in each epoch'
+    )
+    parser.add_argument(
+        '--sample_length', type=int,
+        help='length of each generated sample (in samples)'
+    )
+    parser.add_argument(
+        '--loss_smoothing', type=float,
+        help='smoothing parameter of the exponential moving average over \
+              training loss, used in the log and in the loss plot'
+    )
+
+    parser.set_defaults(**default_params)
+
+    main(**vars(parser.parse_args()))
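
The results directory for a run is named by make_tag, which keeps only the keys that are absent from default_params or that override a default, so directory names stay short while still identifying the run. A minimal sketch of this behavior, assuming the script above is importable as train; the experiment name TEST, the dataset name piano, and the frame sizes 16 and 4 are hypothetical example values, not part of this commit:

    # Equivalent hypothetical CLI call:
    # python train.py --exp TEST --frame_sizes 16 4 --n_rnn 2 --dataset piano
    from train import default_params, make_tag

    params = dict(default_params,
                  exp='TEST', frame_sizes=[16, 4], n_rnn=2, dataset='piano')
    print(make_tag(params))
    # exp:TEST-frame_sizes:16,4-n_rnn:2-dataset:piano

setup_results_dir then creates results/<tag>/ with checkpoints and samples subdirectories, and tee_stdout mirrors all console output into results/<tag>/log.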
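
The train/validation/test split is expressed as fractions of the dataset, counted from the end: the test set takes the last test_frac of the data, the validation set the val_frac before it, and training gets the rest. Tracing the arithmetic in main with the default fractions of 0.1 each:

    test_frac, val_frac = 0.1, 0.1      # defaults
    test_split = 1 - test_frac          # 0.9
    val_split = test_split - val_frac   # 0.8
    # data_loader(0, 0.8, eval=False)   -> training set (shuffled, drop_last)
    # data_loader(0.8, 0.9, eval=True)  -> validation set
    # data_loader(0.9, 1, eval=True)    -> test set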
