summaryrefslogtreecommitdiff
path: root/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils.py')
-rw-r--r--utils.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..320fe95
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,20 @@
+import torch
+from torch import nn
+import numpy as np
+
+
+EPSILON = 1e-5
+
+def linear_quantize(samples, q_levels):
+ samples = samples.clone()
+ samples -= samples.min(dim=-1)[0].expand_as(samples)
+ samples /= samples.max(dim=-1)[0].expand_as(samples)
+ samples *= q_levels - EPSILON
+ samples += EPSILON / 2
+ return samples.long()
+
+def linear_dequantize(samples, q_levels):
+ return samples.float() / (q_levels / 2) - 1
+
+def q_zero(q_levels):
+ return q_levels // 2