chore: remove old quantization impl
ivarflakstad committed Aug 30, 2024
1 parent 1b776e2 commit b85ade6
Showing 4 changed files with 9 additions and 223 deletions.
crates/ratchet-core/src/cpu/mod.rs (3 additions & 9 deletions)
@@ -1,8 +1,8 @@
 pub mod gemm;
 
 use crate::{
-    Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect, InvariantError, OpGuards,
-    Operation, OperationError, Quantization, Quantizer, RVec, Storage, StorageView, Tensor,
+    dequantize, Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect,
+    InvariantError, OpGuards, Operation, OperationError, RVec, Storage, StorageView, Tensor,
     TensorDType, Unary, UnaryOp,
 };
 use anyhow::anyhow;
@@ -228,13 +228,7 @@ fn qindex_select(op: IndexSelect, dst: Tensor) -> Result<Tensor, OperationError>
     let src = op.src().deep_clone();
 
     // NOTE: Support for other quantization types is dependent on the corresponding dequantization functions.
-    let src = match src.dt() {
-        DType::Q8_0F(_) => {
-            let quantizer = Quantizer::new(Quantization::SInt8);
-            quantizer.sint8_dequantize(src)
-        }
-        _ => return Err(InvariantError::UnsupportedDType(src.dt()).into()),
-    };
+    let src = dequantize(src);
     let indices = op.indices().clone();
     let dim = op.dim();
 
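The deleted match above was the old, Q8_0F-only path; the dispatch over quantized dtypes now happens inside the dequantize free function in quant.rs (its signature is visible in the quant.rs hunk further down). A standalone toy sketch of that shape, using stand-in types rather than ratchet's actual Tensor and DType:

    // Toy stand-ins, only to illustrate the dispatch shape; not ratchet's API.
    #[derive(Debug, Clone, Copy, PartialEq)]
    enum ToyDType {
        F32,
        Q8_0F,
        Q4_KF,
    }

    struct ToyTensor {
        dt: ToyDType,
    }

    // Callers such as qindex_select no longer match on the dtype themselves;
    // they hand any supported quantized tensor to a single entry point.
    fn toy_dequantize(t: ToyTensor) -> ToyTensor {
        match t.dt {
            // A real implementation runs the per-format kernel here.
            ToyDType::Q8_0F | ToyDType::Q4_KF => ToyTensor { dt: ToyDType::F32 },
            ToyDType::F32 => t,
        }
    }

    fn main() {
        let src = ToyTensor { dt: ToyDType::Q8_0F };
        assert_eq!(toy_dequantize(src).dt, ToyDType::F32);
    }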
crates/ratchet-core/src/ops/select.rs (3 additions & 3 deletions)
@@ -347,10 +347,10 @@ def index_select(input, indices):
     }
 
     #[test]
-    fn qindex_select() {
+    fn test_qindex_select() {
         let prob = IndexSelectProblem {
-            input_shape: shape![52000, 1280],
-            indices: Tensor::from_data(vec![50258, 50259, 50360], shape![3], Device::CPU),
+            input_shape: shape![256, 32],
+            indices: Tensor::from_data(vec![64, 192, 255], shape![3], Device::CPU),
         };
         let device = Device::request_device(DeviceRequest::GPU).unwrap();
         run_index_select_trial(prob.clone(), device, true);
crates/ratchet-core/src/quant.rs (3 additions & 210 deletions)
@@ -1,19 +1,8 @@
-use num::integer::div_floor;
-use num_traits::{AsPrimitive, Float, FromPrimitive, Zero};
-
-use std::fmt::Debug;
-
 use crate::{
     dtype::Quantized, gpu::STORAGE_BUFFER_ALIGN, DType, Device, Tensor, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
 };
-
-/// Quantizer
-///
-/// Packs weights into our custom quantization formats.
-#[derive(Debug, derive_new::new)]
-pub struct Quantizer {
-    format: Quantization,
-}
+use num::integer::div_floor;
+use num_traits::{AsPrimitive, Float, FromPrimitive, Zero};
 
 #[inline]
 fn storage_align<T>(n: usize) -> usize {
@@ -174,186 +163,10 @@ pub fn dequantize(quantized: Tensor) -> Tensor {
     };
 }
 
-impl Quantizer {
-    /// Quantizes a float 32 tensor into a packed uint32 tensor.
-    pub fn sint8_quantize(&self, tensor: Tensor) -> Tensor {
-        let numel = tensor.shape().numel();
-        let pack_size = self.format.pack_size();
-        let group_size = self.format.group_size();
-
-        assert!(numel % pack_size == 0 && numel % group_size == 0);
-        assert!(tensor.dt() == DType::F32); //TODO: f16, bf16
-        //TODO: check if tensor is contiguous
-        let qmatrix_len = numel / pack_size;
-        let amatrix_len = numel / group_size;
-
-        //returns the aligned number of ELEMENTS
-        let aligner = |numel: usize, size_t: usize| -> usize {
-            let nbytes = numel * size_t;
-            let aligned = if nbytes % STORAGE_BUFFER_ALIGN != 0 {
-                nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN
-            } else {
-                nbytes
-            };
-            aligned / size_t
-        };
-
-        let mut quantized_matrix = vec![0u32; aligner(qmatrix_len, std::mem::size_of::<u32>())];
-        let mut absmax_matrix = vec![0f32; aligner(amatrix_len, std::mem::size_of::<f32>())];
-
-        let mut block_absmax = f32::NEG_INFINITY;
-
-        let matrix = tensor.to_vec::<f32>().unwrap();
-
-        for i in (0..numel).step_by(pack_size) {
-            if i % group_size == 0 {
-                let amax = matrix[i..i + group_size]
-                    .iter()
-                    .fold(f32::NEG_INFINITY, |acc, &x| acc.max(x.abs()));
-                let d = amax / ((1 << 7) - 1) as f32;
-                block_absmax = d;
-            }
-            let packed_value: i32 = ((matrix[i] / block_absmax).round() as i32 & 0xFF)
-                | (((matrix[i + 1] / block_absmax).round() as i32 & 0xFF) << 8)
-                | (((matrix[i + 2] / block_absmax).round() as i32 & 0xFF) << 16)
-                | (((matrix[i + 3] / block_absmax).round() as i32 & 0xFF) << 24);
-            quantized_matrix[i / pack_size] = packed_value as u32;
-            absmax_matrix[i / group_size] = block_absmax;
-        }
-        quantized_matrix.append(&mut unsafe { std::mem::transmute(absmax_matrix) });
-        unsafe {
-            Tensor::from_quantized(
-                quantized_matrix,
-                DType::Q8_0F(Q8_0F::default()),
-                tensor.shape().clone(),
-                Device::CPU,
-            )
-        }
-    }
-
-    pub fn sint8_dequantize(&self, quantized: Tensor) -> Tensor {
-        assert!(matches!(quantized.dt(), DType::Q8_0F(_)));
-        let numel = quantized.shape().numel();
-        let original_shape = quantized.shape().clone();
-        let aligner = |numel: usize, size_t: usize| -> usize {
-            let nbytes = numel * size_t;
-
-            if nbytes % STORAGE_BUFFER_ALIGN != 0 {
-                nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN
-            } else {
-                nbytes
-            }
-        };
-
-        let pack_size = self.format.pack_size();
-        let group_size = self.format.group_size();
-
-        let num_q = numel / pack_size;
-        let num_q_bytes = num_q * std::mem::size_of::<u32>();
-        let aligned_q_bytes = aligner(num_q, std::mem::size_of::<u32>());
-
-        let num_absmax = numel / group_size;
-        let num_absmax_bytes = num_absmax * std::mem::size_of::<f32>();
-
-        let raw_bytes = unsafe { quantized.into_bytes().unwrap() };
-
-        let quantized_matrix = bytemuck::cast_slice::<u8, u32>(&raw_bytes[..num_q_bytes]);
-        let absmax_matrix = bytemuck::cast_slice::<u8, f32>(
-            &raw_bytes[aligned_q_bytes..aligned_q_bytes + num_absmax_bytes],
-        );
-
-        let mut dequantized = vec![0.0f32; numel];
-
-        for i in (0..numel).step_by(pack_size) {
-            let block_absmax = absmax_matrix[div_floor(i, group_size)];
-            let packed_value = quantized_matrix[div_floor(i, pack_size)] as i32;
-            dequantized[i] = ((packed_value << 24) >> 24) as f32 * block_absmax;
-            dequantized[i + 1] = ((packed_value << 16) >> 24) as f32 * block_absmax;
-            dequantized[i + 2] = ((packed_value << 8) >> 24) as f32 * block_absmax;
-            dequantized[i + 3] = (packed_value >> 24) as f32 * block_absmax;
-        }
-
-        Tensor::from_data(dequantized, original_shape, Device::CPU)
-    }
-
-    pub fn sint4_quantize<F: Float + AsPrimitive<i32> + Debug>(
-        matrix: &[F],
-        K: usize,
-        N: usize,
-    ) -> (Vec<u32>, F) {
-        assert!(matrix.len() == K * N);
-        assert!(matrix.len() % 4 == 0);
-        assert!(matrix.len() % 32 == 0);
-        let pack_size = 8;
-        let mut quantized_matrix = vec![0u32; K * N / pack_size];
-
-        let absmax = matrix.iter().fold(F::zero(), |acc, &x| acc.max(x.abs()));
-        let sf = F::from(7.).unwrap();
-
-        for i in (0..(K * N)).step_by(pack_size) {
-            let packed_value: i32 = ((matrix[i] / absmax * sf).round().as_() & 0xF)
-                | (((matrix[i + 1] / absmax * sf).round().as_() & 0xF) << 4)
-                | (((matrix[i + 2] / absmax * sf).round().as_() & 0xF) << 8)
-                | (((matrix[i + 3] / absmax * sf).round().as_() & 0xF) << 12)
-                | (((matrix[i + 4] / absmax * sf).round().as_() & 0xF) << 16)
-                | (((matrix[i + 5] / absmax * sf).round().as_() & 0xF) << 20)
-                | (((matrix[i + 6] / absmax * sf).round().as_() & 0xF) << 24)
-                | (((matrix[i + 7] / absmax * sf).round().as_() & 0xF) << 28);
-            quantized_matrix[i / pack_size] = packed_value as u32
-        }
-        (quantized_matrix, absmax)
-    }
-
-    pub fn sint4_dequantize(quantized_matrix: &[u32], absmax: f32, K: usize, N: usize) -> Vec<f32> {
-        let pack_size = 8;
-        let mut matrix = vec![0.0; K * N];
-
-        for i in (0..(K * N)).step_by(pack_size) {
-            let packed_value = quantized_matrix[div_floor(i, pack_size)] as i32;
-            matrix[i] = ((packed_value << 28) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 1] = ((packed_value << 24) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 2] = ((packed_value << 20) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 3] = ((packed_value << 16) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 4] = ((packed_value << 12) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 5] = ((packed_value << 8) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 6] = ((packed_value << 4) >> 28) as f32 / 7.0 * absmax;
-            matrix[i + 7] = (packed_value >> 28) as f32 / 7.0 * absmax;
-        }
-
-        matrix
-    }
-}
-
-#[derive(Debug, Clone, Copy)]
-pub enum Quantization {
-    None,
-    SInt8,
-    SInt4,
-}
-
-impl Quantization {
-    pub fn pack_size(&self) -> usize {
-        match self {
-            Quantization::None => 1,
-            Quantization::SInt8 => 4,
-            Quantization::SInt4 => 8,
-        }
-    }
-
-    pub fn group_size(&self) -> usize {
-        match self {
-            Quantization::None => 1,
-            Quantization::SInt8 => 32,
-            Quantization::SInt4 => 8,
-        }
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use crate::{
-        dequantize, quantize, shape, Device, Quantization, Quantized, Quantizer, Tensor, Q4_KF,
-        Q4_KH, Q8_0F, Q8_0H,
+        dequantize, quantize, shape, Device, Quantized, Tensor, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
     };
     use half::f16;
 
@@ -375,24 +188,4 @@ mod tests {
         check_qd_reflexive::<Q4_KF>(0.3, 0.3);
         check_qd_reflexive::<Q4_KH>(f16::from_f32(0.3), f16::from_f32(0.3));
     }
-
-    #[test]
-    pub fn test_sint8_float_qdq() {
-        let ground = Tensor::randn::<f32>(shape![64, 64], Device::CPU);
-
-        // Old api
-        let quantizer = Quantizer::new(Quantization::SInt8);
-        let q1 = quantizer.sint8_quantize(ground.deep_clone());
-        let dq1 = quantizer.sint8_dequantize(q1.deep_clone());
-
-        // New api
-        let q2 = quantize::<Q8_0F>(&ground);
-        let dq2 = dequantize(q2.deep_clone());
-
-        let q1_raw = unsafe { q1.deep_clone().into_bytes().unwrap() };
-        let q2_raw = unsafe { q2.deep_clone().into_bytes().unwrap() };
-        assert_eq!(q1_raw, q2_raw);
-
-        dq1.all_close(&dq2, 1e-3, 1e-3).unwrap();
-    }
 }
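For reference, the removed sint8 path packed four signed 8-bit values into each u32 and kept one f32 absmax scale per group of 32 elements (the pack_size = 4 and group_size = 32 of the deleted SInt8 variant). Below is a self-contained sketch of that scheme on plain slices; it skips the STORAGE_BUFFER_ALIGN padding the removed code applied, and all names are illustrative rather than part of ratchet:

    const PACK_SIZE: usize = 4; // four i8 values per u32, as in the deleted SInt8 format
    const GROUP_SIZE: usize = 32; // one absmax scale per 32 elements

    fn pack_sint8(values: &[f32]) -> (Vec<u32>, Vec<f32>) {
        assert!(values.len() % GROUP_SIZE == 0);
        let mut packed = vec![0u32; values.len() / PACK_SIZE];
        let mut scales = vec![0f32; values.len() / GROUP_SIZE];
        for (g, group) in values.chunks(GROUP_SIZE).enumerate() {
            // Per-group scale: absmax mapped onto the signed 8-bit range.
            let absmax = group.iter().fold(0f32, |acc, &x| acc.max(x.abs()));
            let d = absmax / 127.0;
            scales[g] = d;
            for (p, quad) in group.chunks(PACK_SIZE).enumerate() {
                let mut word = 0u32;
                for (lane, &x) in quad.iter().enumerate() {
                    let q = if d == 0.0 { 0 } else { (x / d).round() as i32 };
                    word |= ((q & 0xFF) as u32) << (8 * lane);
                }
                packed[g * GROUP_SIZE / PACK_SIZE + p] = word;
            }
        }
        (packed, scales)
    }

    fn unpack_sint8(packed: &[u32], scales: &[f32]) -> Vec<f32> {
        let numel = packed.len() * PACK_SIZE;
        let mut out = vec![0f32; numel];
        for i in 0..numel {
            let word = packed[i / PACK_SIZE] as i32;
            // Shift left, then arithmetic-shift right to sign-extend the selected byte,
            // mirroring the removed sint8_dequantize.
            let lane = (i % PACK_SIZE) as i32;
            let q = (word << (24 - 8 * lane)) >> 24;
            out[i] = q as f32 * scales[i / GROUP_SIZE];
        }
        out
    }

    fn main() {
        let data: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        let (packed, scales) = pack_sint8(&data);
        let restored = unpack_sint8(&packed, &scales);
        let max_err = data
            .iter()
            .zip(&restored)
            .map(|(a, b)| (a - b).abs())
            .fold(0f32, f32::max);
        assert!(max_err < 0.05); // error bounded by half the per-group step size
        println!("max round-trip error: {max_err}");
    }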
crates/ratchet-nn/src/embedding.rs (0 additions & 1 deletion)
@@ -30,7 +30,6 @@ mod tests {
     use hf_hub::api::sync::Api;
     use proptest::arbitrary::Arbitrary;
    use proptest::strategy::{BoxedStrategy, Just, Strategy};
-    use ratchet::{Quantization, Quantizer};
     use ratchet_loader::gguf::gguf::Header;
     use test_strategy::proptest;
     use tokenizers::Tokenizer;
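With the old Quantizer gone, the surviving path is the quantize/dequantize free functions exercised by the retained reflexivity test in quant.rs. A usage sketch mirroring the deleted comparison test; it assumes the ratchet crate re-exports these items the way the deleted embedding test imported Quantizer, and the tolerances are illustrative:

    use ratchet::{dequantize, quantize, shape, Device, Tensor, Q8_0F};

    fn q8_roundtrip() {
        let ground = Tensor::randn::<f32>(shape![64, 64], Device::CPU);
        // New API: quantize into the Q8_0F block format, then dequantize back.
        let q = quantize::<Q8_0F>(&ground);
        let dq = dequantize(q.deep_clone());
        // Quantization is lossy, so compare against the original loosely.
        dq.all_close(&ground, 1e-1, 1e-1).unwrap();
    }

The retained check_qd_reflexive test covers this round trip per format, so the sketch only shows the caller-facing shape of the API that replaces the removed Quantizer methods.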
