diff options
| author | Piotr Kozakowski <kozak000@gmail.com> | 2017-11-19 20:23:26 +0100 |
|---|---|---|
| committer | Piotr Kozakowski <kozak000@gmail.com> | 2017-11-19 20:23:26 +0100 |
| commit | 4167442627b1414ff8fdc86528812b46168c656b (patch) | |
| tree | f5020d2161762fad2db56f3f9ddcb3ad2deec553 /train.py | |
| parent | 61e935ff5a90c8c7b9a5a5f2f54d4ec8f9742dc0 (diff) | |
Add weight normalization
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 8 |
1 files changed, 7 insertions, 1 deletions
@@ -35,6 +35,7 @@ default_params = { 'learn_h0': True, 'q_levels': 256, 'seq_len': 1024, + 'weight_norm': True, 'batch_size': 128, 'val_frac': 0.1, 'test_frac': 0.1, @@ -175,7 +176,8 @@ def main(exp, frame_sizes, dataset, **params): n_rnn=params['n_rnn'], dim=params['dim'], learn_h0=params['learn_h0'], - q_levels=params['q_levels'] + q_levels=params['q_levels'], + weight_norm=params['weight_norm'] ) predictor = Predictor(model) if params['cuda']: @@ -300,6 +302,10 @@ if __name__ == '__main__': '--seq_len', type=int, help='how many samples to include in each truncated BPTT pass' ) + parser.add_argument( + '--weight_norm', type=parse_bool, + help='whether to use weight normalization' + ) parser.add_argument('--batch_size', type=int, help='batch size') parser.add_argument( '--val_frac', type=float, |
