summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/train.py b/train.py
index 8bab87b..e3061c9 100644
--- a/train.py
+++ b/train.py
@@ -35,6 +35,7 @@ default_params = {
'learn_h0': True,
'q_levels': 256,
'seq_len': 1024,
+ 'weight_norm': True,
'batch_size': 128,
'val_frac': 0.1,
'test_frac': 0.1,
@@ -175,7 +176,8 @@ def main(exp, frame_sizes, dataset, **params):
n_rnn=params['n_rnn'],
dim=params['dim'],
learn_h0=params['learn_h0'],
- q_levels=params['q_levels']
+ q_levels=params['q_levels'],
+ weight_norm=params['weight_norm']
)
predictor = Predictor(model)
if params['cuda']:
@@ -300,6 +302,10 @@ if __name__ == '__main__':
'--seq_len', type=int,
help='how many samples to include in each truncated BPTT pass'
)
+ parser.add_argument(
+ '--weight_norm', type=parse_bool,
+ help='whether to use weight normalization'
+ )
parser.add_argument('--batch_size', type=int, help='batch size')
parser.add_argument(
'--val_frac', type=float,