blob: 183ce537df6ed50b68b4c09d6276f1d4fed57eb8 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
import torch
from torch import nn
import numpy as np
# Safety margin for linear_quantize: scaling into
# [EPSILON / 2, q_levels - EPSILON / 2] keeps every value strictly inside
# [0, q_levels), so truncation by ``.long()`` can never produce ``q_levels``.
EPSILON = 1e-2


def linear_quantize(samples, q_levels):
    """Linearly quantize ``samples`` to integer levels ``0 .. q_levels - 1``.

    Each slice along the last dimension is normalized independently: its
    minimum maps to level 0 and its maximum to level ``q_levels - 1``.

    Args:
        samples: Tensor of any shape; quantization is done per slice along
            dim -1. Non-floating-point tensors are promoted to float32.
        q_levels: Number of quantization levels.

    Returns:
        An int64 tensor with the same shape as ``samples``. The input tensor
        is not modified.
    """
    if not samples.is_floating_point():
        # True division on an integer tensor cannot be done in that dtype;
        # promote so integer inputs work instead of raising.
        samples = samples.float()
    # keepdim=True keeps the reduced axis so per-slice statistics broadcast
    # along dim -1. Without it, min/max drop that axis and ``expand_as`` on a
    # (B, T) input aligns the (B,) result with T: an error when B != T, and
    # silently wrong per-row normalization when B == T.
    lo = samples.min(dim=-1, keepdim=True)[0]
    hi = samples.max(dim=-1, keepdim=True)[0]
    span = hi - lo
    # A constant slice has zero range; dividing by it would give NaN, whose
    # integer cast is meaningless. Use 1 so such slices map to level 0.
    span = torch.where(span > 0, span, torch.ones_like(span))
    scaled = (samples - lo) / span  # each slice now spans [0, 1]
    scaled = scaled * (q_levels - EPSILON) + EPSILON / 2
    # All values are positive here, so truncation acts as floor.
    return scaled.long()
def linear_dequantize(samples, q_levels):
    """Convert integer quantization levels back to floats.

    Computes ``level / (q_levels / 2) - 1``. Levels in ``[0, q_levels)`` land
    in ``[-1, 1 - 2 / q_levels]``, with level 0 mapping to -1.0.

    Args:
        samples: Tensor of quantization levels.
        q_levels: Number of quantization levels.

    Returns:
        A float32 tensor with the same shape as ``samples``.
    """
    scale = q_levels / 2
    return samples.float().div(scale).sub(1)
def q_zero(q_levels):
    """Return the middle quantization level, ``q_levels // 2``.

    For an even ``q_levels`` this is the level that ``linear_dequantize``
    maps to exactly 0.0.
    """
    midpoint = q_levels // 2
    return midpoint
|