| field | value | |
|---|---|---|
| author | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 19:15:58 +0200 |
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-05-14 19:15:58 +0200 |
| commit | 60fb2b7c87b7e6aa179c6a973a8d6e39cbe7c594 | |
| tree | 8a738a43e8583f38f151cdc643a38b5a9437cda2 /generate.py | |
| parent | 9766ef0f3a539be7ee68bb93918f25a3298afe39 | |
| parent | e2d8a6f26c5e44d970d7c069f171105376835495 | |
Merge branch 'master' of asdf.us:samplernn
Diffstat (limited to 'generate.py')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | generate.py | 365 |

1 file changed, 365 insertions(+), 0 deletions(-)
```diff
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000..92a930f
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,365 @@
```

```python
# CometML needs to be imported first.
try:
    import comet_ml
except ImportError:
    pass

from model import SampleRNN, Predictor
from optim import gradient_clipping
from nn import sequence_nll_loss_bits
from trainer import Trainer
from trainer.plugins import (
    TrainingLossMonitor, ValidationPlugin, AbsoluteTimeMonitor, SaverPlugin,
    GeneratorPlugin, StatsPlugin
)
from dataset import FolderDataset, DataLoader

import torch
from torch.utils.trainer.plugins import Logger

from natsort import natsorted

from functools import reduce
import os
import shutil
import sys
from glob import glob
import re
import argparse


default_params = {
    # model parameters
    'n_rnn': 1,
    'dim': 1024,
    'learn_h0': True,
    'q_levels': 256,
    'seq_len': 1024,
    'weight_norm': True,
    'batch_size': 128,
    'val_frac': 0.1,
    'test_frac': 0.1,

    # training parameters
    'keep_old_checkpoints': False,
    'datasets_path': 'datasets',
    'results_path': 'results',
    'epoch_limit': 1000,
    'resume': True,
    'sample_rate': 16000,
    'n_samples': 1,
    'sample_length': 80000,
    'loss_smoothing': 0.99,
    'cuda': True,
    'comet_key': None
}

tag_params = [
    'exp', 'frame_sizes', 'n_rnn', 'dim', 'learn_h0', 'q_levels', 'seq_len',
    'batch_size', 'dataset', 'val_frac', 'test_frac'
]

def param_to_string(value):
    if isinstance(value, bool):
        return 'T' if value else 'F'
    elif isinstance(value, list):
        return ','.join(map(param_to_string, value))
    else:
        return str(value)

def make_tag(params):
    return '-'.join(
        key + ':' + param_to_string(params[key])
        for key in tag_params
        if key not in default_params or params[key] != default_params[key]
    )

def setup_results_dir(params):
    def ensure_dir_exists(path):
        if not os.path.exists(path):
            os.makedirs(path)

    tag = make_tag(params)
    results_path = os.path.abspath(params['results_path'])
    ensure_dir_exists(results_path)
    results_path = os.path.join(results_path, tag)
    if not os.path.exists(results_path):
        os.makedirs(results_path)
    elif not params['resume']:
        shutil.rmtree(results_path)
        os.makedirs(results_path)

    for subdir in ['checkpoints', 'samples']:
        ensure_dir_exists(os.path.join(results_path, subdir))

    return results_path

def load_last_checkpoint(checkpoints_path):
    checkpoints_pattern = os.path.join(
        checkpoints_path, SaverPlugin.last_pattern.format('*', '*')
    )
    checkpoint_paths = natsorted(glob(checkpoints_pattern))
    if len(checkpoint_paths) > 0:
        checkpoint_path = checkpoint_paths[-1]
        checkpoint_name = os.path.basename(checkpoint_path)
        match = re.match(
            SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'),
            checkpoint_name
        )
        epoch = int(match.group(1))
        iteration = int(match.group(2))
        return (torch.load(checkpoint_path), epoch, iteration)
    else:
        return None

def tee_stdout(log_path):
    log_file = open(log_path, 'a', 1)
    stdout = sys.stdout

    class Tee:

        def write(self, string):
            log_file.write(string)
            stdout.write(string)

        def flush(self):
            log_file.flush()
            stdout.flush()

    sys.stdout = Tee()

def make_data_loader(overlap_len, params):
    path = os.path.join(params['datasets_path'], params['dataset'])
    def data_loader(split_from, split_to, eval):
        dataset = FolderDataset(
            path, overlap_len, params['q_levels'], split_from, split_to
        )
        return DataLoader(
            dataset,
            batch_size=params['batch_size'],
            seq_len=params['seq_len'],
            overlap_len=overlap_len,
            shuffle=(not eval),
            drop_last=(not eval)
        )
    return data_loader

def init_comet(params, trainer):
    if params['comet_key'] is not None:
        from comet_ml import Experiment
        from trainer.plugins import CometPlugin
        experiment = Experiment(api_key=params['comet_key'], log_code=False)
        hyperparams = {
            name: param_to_string(params[name]) for name in tag_params
        }
        experiment.log_multiple_params(hyperparams)
        trainer.register_plugin(CometPlugin(
            experiment, [
                ('training_loss', 'epoch_mean'),
                'validation_loss',
                'test_loss'
            ]
        ))

def main(exp, frame_sizes, dataset, **params):
    params = dict(
        default_params,
        exp=exp, frame_sizes=frame_sizes, dataset=dataset,
        **params
    )

    results_path = setup_results_dir(params)
    tee_stdout(os.path.join(results_path, 'log'))

    model = SampleRNN(
        frame_sizes=params['frame_sizes'],
        n_rnn=params['n_rnn'],
        dim=params['dim'],
        learn_h0=params['learn_h0'],
        q_levels=params['q_levels'],
        weight_norm=params['weight_norm']
    )
    predictor = Predictor(model)
    if params['cuda']:
        model = model.cuda()
        predictor = predictor.cuda()

    optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters()))

    data_loader = make_data_loader(model.lookback, params)
    test_split = 1 - params['test_frac']
    val_split = test_split - params['val_frac']

    trainer = Trainer(
        predictor, sequence_nll_loss_bits, optimizer,
        data_loader(0, val_split, eval=False),
        cuda=params['cuda']
    )

    checkpoints_path = os.path.join(results_path, 'checkpoints')
    checkpoint_data = load_last_checkpoint(checkpoints_path)
    if checkpoint_data is not None:
        (state_dict, epoch, iteration) = checkpoint_data
        trainer.epochs = epoch
        trainer.iterations = iteration
        predictor.load_state_dict(state_dict)
        print("epochs: {} iterations: {}".format(epoch, iteration))

    """
    trainer.register_plugin(TrainingLossMonitor(
        smoothing=params['loss_smoothing']
    ))
    """
    trainer.register_plugin(ValidationPlugin(
        data_loader(val_split, test_split, eval=True),
        data_loader(test_split, 1, eval=True)
    ))
    trainer.register_plugin(AbsoluteTimeMonitor())
    """
    trainer.register_plugin(SaverPlugin(
        checkpoints_path, params['keep_old_checkpoints']
    ))
    """
    trainer.register_plugin(GeneratorPlugin(
        os.path.join(results_path, 'samples'), params['n_samples'],
        params['sample_length'], params['sample_rate']
    ))
    """
    trainer.register_plugin(
        Logger([
            'training_loss',
            'validation_loss',
            'test_loss',
            'time'
        ])
    )
    trainer.register_plugin(StatsPlugin(
        results_path,
        iteration_fields=[
            'training_loss',
            ('training_loss', 'running_avg'),
            'time'
        ],
        epoch_fields=[
            'validation_loss',
            'test_loss',
            'time'
        ],
        plots={
            'loss': {
                'x': 'iteration',
                'ys': [
                    'training_loss',
                    ('training_loss', 'running_avg'),
                    'validation_loss',
                    'test_loss',
                ],
                'log_y': True
            }
        }
    ))
    init_comet(params, trainer)
    """
    trainer.generate(int(params['epoch_limit']))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        argument_default=argparse.SUPPRESS
    )

    def parse_bool(arg):
        arg = arg.lower()
        if 'true'.startswith(arg):
            return True
        elif 'false'.startswith(arg):
            return False
        else:
            raise ValueError()

    parser.add_argument('--exp', required=True, help='experiment name')
    parser.add_argument(
        '--frame_sizes', nargs='+', type=int, required=True,
        help='frame sizes in terms of the number of lower tier frames, \
              starting from the lowest RNN tier'
    )
    parser.add_argument(
        '--dataset', required=True,
        help='dataset name - name of a directory in the datasets path \
              (settable by --datasets_path)'
    )
    parser.add_argument(
        '--n_rnn', type=int, help='number of RNN layers in each tier'
    )
    parser.add_argument(
        '--dim', type=int, help='number of neurons in every RNN and MLP layer'
    )
    parser.add_argument(
        '--learn_h0', type=parse_bool,
        help='whether to learn the initial states of RNNs'
    )
    parser.add_argument(
        '--q_levels', type=int,
        help='number of bins in quantization of audio samples'
    )
    parser.add_argument(
        '--seq_len', type=int,
        help='how many samples to include in each truncated BPTT pass'
    )
    parser.add_argument(
        '--weight_norm', type=parse_bool,
        help='whether to use weight normalization'
    )
    parser.add_argument('--batch_size', type=int, help='batch size')
    parser.add_argument(
        '--val_frac', type=float,
        help='fraction of data to go into the validation set'
    )
    parser.add_argument(
        '--test_frac', type=float,
        help='fraction of data to go into the test set'
    )
    parser.add_argument(
        '--keep_old_checkpoints', type=parse_bool,
        help='whether to keep checkpoints from past epochs'
    )
    parser.add_argument(
        '--datasets_path', help='path to the directory containing datasets'
    )
    parser.add_argument(
        '--results_path', help='path to the directory to save the results to'
    )
    parser.add_argument('--epoch_limit', help='how many epochs to run')
    parser.add_argument(
        '--resume', type=parse_bool, default=True,
        help='whether to resume training from the last checkpoint'
    )
    parser.add_argument(
        '--sample_rate', type=int,
        help='sample rate of the training data and generated sound'
    )
    parser.add_argument(
        '--n_samples', type=int,
        help='number of samples to generate in each epoch'
    )
    parser.add_argument(
        '--sample_length', type=int,
        help='length of each generated sample (in samples)'
    )
    parser.add_argument(
        '--loss_smoothing', type=float,
        help='smoothing parameter of the exponential moving average over \
              training loss, used in the log and in the loss plot'
    )
    parser.add_argument(
        '--cuda', type=parse_bool,
        help='whether to use CUDA'
    )
    parser.add_argument(
        '--comet_key', help='comet.ml API key'
    )

    parser.set_defaults(**default_params)

    main(**vars(parser.parse_args()))
```
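With the training-loss monitor, saver, logger, and stats plugins (and the `init_comet` call) fenced off in triple-quoted blocks, the script registers only `ValidationPlugin`, `AbsoluteTimeMonitor`, and `GeneratorPlugin`, and finishes with `trainer.generate(int(params['epoch_limit']))` rather than a training run. The results directory it reads checkpoints from and writes samples into is named by `make_tag`, which serializes every parameter that differs from (or is absent from) `default_params`. A minimal sketch of that naming, using hypothetical `--exp demo --frame_sizes 16 4 --dataset piano` values:

```python
# Sketch of the make_tag naming above; default_params and tag_params are
# abbreviated from the script, and 'demo'/'piano' are hypothetical values.
def param_to_string(value):
    if isinstance(value, bool):
        return 'T' if value else 'F'
    if isinstance(value, list):
        return ','.join(map(param_to_string, value))
    return str(value)

default_params = {'n_rnn': 1, 'dim': 1024}
tag_params = ['exp', 'frame_sizes', 'n_rnn', 'dim', 'dataset']

params = dict(default_params, exp='demo', frame_sizes=[16, 4], dataset='piano')
tag = '-'.join(
    key + ':' + param_to_string(params[key])
    for key in tag_params
    # keys left at their defaults are omitted from the tag
    if key not in default_params or params[key] != default_params[key]
)
print(tag)  # -> exp:demo-frame_sizes:16,4-dataset:piano
```

Under the default `results_path`, a run with those flags would therefore look for checkpoints in `results/exp:demo-frame_sizes:16,4-dataset:piano/checkpoints/` and write generated audio into the sibling `samples/` directory.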

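The data splits in `main` are expressed as fractions of the dataset rather than explicit indices: the test set takes the final `test_frac` of the data, validation the slice just before it, and training everything below `val_split`. A worked example with the default fractions:

```python
# Worked example of the split arithmetic in main(), using the defaults
# above (val_frac = test_frac = 0.1).
test_frac, val_frac = 0.1, 0.1
test_split = 1 - test_frac         # 0.9
val_split = test_split - val_frac  # 0.8

# data_loader(0, val_split, eval=False)          -> training set, first 80%
# data_loader(val_split, test_split, eval=True)  -> validation set, next 10%
# data_loader(test_split, 1, eval=True)          -> test set, last 10%
```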