Skip to content

Commit

Permalink
fix observer bugs (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran authored Apr 16, 2024
1 parent 2494e46 commit 20283a0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/sparsetensors/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .quant_args import *
from .quant_config import *
from .quant_scheme import *
from .lifecycle import *
2 changes: 1 addition & 1 deletion src/sparsetensors/quantization/observers/memoryless.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
# scales from a 0 range should be set to 1
scale[observed_range == 0] = 1

zero_point = (0 - min_val) / scale
zero_point = ((0 - min_val) / scale).to(torch.int8)

return scale, zero_point
12 changes: 6 additions & 6 deletions src/sparsetensors/quantization/observers/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
min_val = torch.tensor([observed.min()])
max_val = torch.tensor([observed.max()])

# running average
# update running average
if self.counter > 0:
self.min_val = (self.min_val * self.counter + min_val) / (self.counter + 1)
self.max_val = (self.max_val * self.counter + max_val) / (self.counter + 1)
Expand All @@ -57,23 +57,23 @@ def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
self.max_val = max_val

# ensure that the zeros are in the range
self.min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
self.max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))
min_val = torch.min(self.min_val, torch.zeros_like(self.min_val))
max_val = torch.max(self.max_val, torch.zeros_like(self.max_val))

self.counter += 1

if self.quantization_args.symmetric:
symmetric_range = 2 * max(self.min_val.abs(), self.max_val.abs())
symmetric_range = 2 * max(min_val.abs(), max_val.abs())
scale = symmetric_range / bit_range
zero_point = torch.tensor(0).to(torch.int8)
else:
# non-symmetric
observed_range = self.max_val - self.min_val
observed_range = max_val - min_val
scale = observed_range / bit_range

# scales from a 0 range should be set to 1
scale[observed_range == 0] = 1

zero_point = (0 - self.min_val) / scale
zero_point = ((0 - min_val) / scale).to(torch.int8)

return scale, zero_point

0 comments on commit 20283a0

Please sign in to comment.