Commit

stash for debugging

ivarflakstad committed Aug 28, 2024
1 parent 51f327b commit 500eb1a
Showing 2 changed files with 143 additions and 28 deletions.
crates/ratchet-core/src/dtype/blocks.rs: 25 additions & 3 deletions
@@ -3,7 +3,7 @@
 ///
 /// We closely follow the memory layout of the original GGUF implementation,
 /// but often need 2 variants of each block type for devices that don't support f16.
-use crate::{rvec, Align, BufferSegment, RVec, TensorDType};
+use crate::{rvec, Align, BufferSegment, DType, RVec, TensorDType};
 use derive_new::new;
 use half::f16;
 use num_traits::{AsPrimitive, Float, FromPrimitive};
@@ -169,24 +169,46 @@ pub trait Quantized {
     type FP: TensorDType + Float + AsPrimitive<i32> + FromPrimitive + Copy + PartialEq;
     const PACK_SIZE: usize;
     const GROUP_SIZE: usize;
+
+    const LSHIFT: usize = Self::GROUP_SIZE / Self::PACK_SIZE;
+    const MASK: i32 = (1 << Self::LSHIFT) - 1;
+    const RSHIFT: usize = Self::GROUP_SIZE - Self::LSHIFT;
+
+    fn dt() -> DType;
 }
 impl Quantized for Q8_0F {
     type FP = f32;
     const PACK_SIZE: usize = 4;
     const GROUP_SIZE: usize = 32;
+
+    fn dt() -> DType {
+        DType::Q8_0F(Q8_0F::default())
+    }
 }
 impl Quantized for Q8_0H {
     type FP = f16;
     const PACK_SIZE: usize = 4;
     const GROUP_SIZE: usize = 32;
+
+    fn dt() -> DType {
+        DType::Q8_0H(Q8_0H::default())
+    }
 }
 impl Quantized for Q4_KF {
     type FP = f32;
     const PACK_SIZE: usize = 8;
-    const GROUP_SIZE: usize = 8;
+    const GROUP_SIZE: usize = 32;
+
+    fn dt() -> DType {
+        DType::Q4_KF(Q4_KF::default())
+    }
 }
 impl Quantized for Q4_KH {
     type FP = f16;
     const PACK_SIZE: usize = 8;
-    const GROUP_SIZE: usize = 8;
+    const GROUP_SIZE: usize = 32;
+
+    fn dt() -> DType {
+        DType::Q4_KH(Q4_KH::default())
+    }
 }
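Not part of the commit, but for reference: a standalone sketch that re-derives the concrete values the new associated constants take for the two block families, straight from the trait definitions above.

```rust
// Derived packing constants, assuming the Quantized trait as committed above.
// LSHIFT = GROUP_SIZE / PACK_SIZE is the bit width of one quantized value,
// MASK keeps exactly that many low bits, and RSHIFT = GROUP_SIZE - LSHIFT is
// the sign-extending shift. RSHIFT equals 32 - LSHIFT only because every
// block type here uses GROUP_SIZE = 32.
fn main() {
    for (name, pack_size, group_size) in [("Q8_0", 4usize, 32usize), ("Q4_K", 8, 32)] {
        let lshift = group_size / pack_size;
        let mask: i32 = (1 << lshift) - 1;
        let rshift = group_size - lshift;
        // Prints: Q8_0: 8 bits per value, mask 0xff, rshift 24
        //         Q4_K: 4 bits per value, mask 0xf, rshift 28
        println!("{name}: {lshift} bits per value, mask {mask:#x}, rshift {rshift}");
    }
}
```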
crates/ratchet-core/src/quant.rs: 118 additions & 25 deletions
Expand Up @@ -4,7 +4,7 @@ use num_traits::{AsPrimitive, Float, FromPrimitive, Zero};
use std::fmt::Debug;

use crate::{
dtype::Quantized, gpu::STORAGE_BUFFER_ALIGN, DType, Device, Tensor, TensorDType, Q8_0F,
dtype::Quantized, gpu::STORAGE_BUFFER_ALIGN, DType, Device, Tensor, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
};

/// Quantizer
Expand All @@ -15,28 +15,30 @@ pub struct Quantizer {
format: Quantization,
}

fn quantize_inner<Q: Quantized>(matrix: &[Q::FP], elements: usize) -> Vec<u32> {
#[inline]
fn storage_align<T>(n: usize) -> usize {
let size_t = core::mem::size_of::<T>();
let nbytes = n * size_t;
let aligned = if nbytes % STORAGE_BUFFER_ALIGN != 0 {
nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN
} else {
nbytes
};
aligned / size_t
}

pub fn quantize_inner<Q: Quantized>(matrix: &[Q::FP], elements: usize) -> Vec<u32> {
println!("quantize_inner");
assert_eq!(elements % Q::PACK_SIZE, 0);
assert_eq!(elements % Q::GROUP_SIZE, 0);

let qmatrix_len = elements / Q::PACK_SIZE;
let amatrix_len = elements / Q::GROUP_SIZE;

//returns the aligned number of ELEMENTS
let aligner = |numel: usize, size_t: usize| -> usize {
let nbytes = numel * size_t;
let aligned = if nbytes % STORAGE_BUFFER_ALIGN != 0 {
nbytes + STORAGE_BUFFER_ALIGN - nbytes % STORAGE_BUFFER_ALIGN
} else {
nbytes
};
aligned / size_t
};

let mut quantized_matrix = vec![0u32; aligner(qmatrix_len, std::mem::size_of::<u32>())];
let mut absmax_matrix = vec![Q::FP::zero(); aligner(amatrix_len, std::mem::size_of::<Q::FP>())];

let mut quantized_matrix = vec![0u32; storage_align::<u32>(qmatrix_len)];
let mut absmax_matrix = vec![Q::FP::zero(); storage_align::<Q::FP>(amatrix_len)];
let mut block_absmax = Q::FP::neg_infinity();

for i in (0..elements).step_by(Q::PACK_SIZE) {
if i % Q::GROUP_SIZE == 0 {
let amax = matrix[i..i + Q::GROUP_SIZE]
Expand All @@ -47,7 +49,7 @@ fn quantize_inner<Q: Quantized>(matrix: &[Q::FP], elements: usize) -> Vec<u32> {
}
for j in 0..Q::PACK_SIZE {
let packed_value: i32 =
((matrix[i + j] / block_absmax).round().as_() & 0xFF) << (j * 8);
((matrix[i + j] / block_absmax).round().as_() & Q::MASK) << (j * Q::LSHIFT);
quantized_matrix[i / Q::PACK_SIZE] |= packed_value as u32;
}
absmax_matrix[i / Q::GROUP_SIZE] = block_absmax;
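The masked expression above generalizes the old hard-coded 0xFF and j * 8 to any field width. A worked sketch of the same packing for the Q8_0 case, with illustrative values (LSHIFT = 8, so each quantized value occupies one byte lane of the u32):

```rust
fn main() {
    // Pretend these were already divided by block_absmax and rounded.
    let values: [i32; 4] = [1, -2, 3, -4];
    let (lshift, mask) = (8usize, 0xFFi32); // Q::LSHIFT, Q::MASK for Q8_0
    let mut packed: u32 = 0;
    for (j, v) in values.iter().enumerate() {
        // Keep the low 8 bits of the (possibly negative) value, then
        // OR it into lane j: the committed expression, constants inlined.
        packed |= ((v & mask) << (j * lshift)) as u32;
    }
    assert_eq!(packed, 0xFC03_FE01); // lanes: 0x01, 0xFE, 0x03, 0xFC
}
```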
@@ -58,8 +60,8 @@ fn quantize_inner<Q: Quantized>(matrix: &[Q::FP], elements: usize) -> Vec<u32> {
 }
 
 pub fn quantize<Q: Quantized>(tensor: &Tensor) -> Tensor {
-    return match tensor.dt() {
-        DType::F32 => {
+    match (tensor.dt(), Q::dt()) {
+        (DType::F32, DType::Q8_0F(_)) => {
             let matrix = tensor.to_vec::<Q::FP>().unwrap();
             unsafe {
                 Tensor::from_quantized(
@@ -70,11 +72,45 @@ pub fn quantize<Q: Quantized>(tensor: &Tensor) -> Tensor {
                 )
             }
         }
-        dt => panic!("Unsupported dtype {dt}"),
-    };
+        (DType::F32, DType::Q4_KF(_)) => {
+            let matrix = tensor.to_vec::<Q::FP>().unwrap();
+            unsafe {
+                Tensor::from_quantized(
+                    quantize_inner::<Q>(&matrix, tensor.shape().numel()),
+                    DType::Q4_KF(Q4_KF::default()),
+                    tensor.shape().clone(),
+                    Device::CPU,
+                )
+            }
+        }
+        (DType::F16, DType::Q8_0H(_)) => {
+            let matrix = tensor.to_vec::<Q::FP>().unwrap();
+            unsafe {
+                Tensor::from_quantized(
+                    quantize_inner::<Q>(&matrix, tensor.shape().numel()),
+                    DType::Q8_0H(Q8_0H::default()),
+                    tensor.shape().clone(),
+                    Device::CPU,
+                )
+            }
+        }
+        (DType::F16, DType::Q4_KH(_)) => {
+            let matrix = tensor.to_vec::<Q::FP>().unwrap();
+            unsafe {
+                Tensor::from_quantized(
+                    quantize_inner::<Q>(&matrix, tensor.shape().numel()),
+                    DType::Q4_KH(Q4_KH::default()),
+                    tensor.shape().clone(),
+                    Device::CPU,
+                )
+            }
+        }
+        (dt, qdt) => panic!("Unsupported dtype combination {dt}, {qdt}"),
+    }
 }
 
 fn dequantize_inner<Q: Quantized>(quantized: &[u8], numel: usize) -> Vec<Q::FP> {
+    println!("dequantize_inner");
     let num_q = numel / Q::PACK_SIZE;
     let num_q_bytes = num_q * std::mem::size_of::<u32>();
     let aligner = |numel: usize, size_t: usize| -> usize {
@@ -101,9 +137,11 @@ fn dequantize_inner<Q: Quantized>(quantized: &[u8], numel: usize) -> Vec<Q::FP>
         let block_absmax = absmax_matrix[div_floor(i, Q::GROUP_SIZE)];
         let packed_value = quantized_matrix[div_floor(i, Q::PACK_SIZE)] as i32;
         for j in 0..Q::PACK_SIZE {
-            dequantized[i + j] =
-                Q::FP::from_i32((packed_value << (8 * (Q::PACK_SIZE - j - 1))) >> 24).unwrap()
-                    * block_absmax;
+            dequantized[i + j] = Q::FP::from_i32(
+                (packed_value << (Q::LSHIFT * (Q::PACK_SIZE - j - 1))) >> Q::RSHIFT,
+            )
+            .unwrap()
+                * block_absmax;
         }
     }
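The new shift pair generalizes the old << (8 * ...) / >> 24: the left shift moves lane j into the top bits of the i32, and the arithmetic right shift brings it back down while replicating the lane's sign bit. Continuing the Q8_0 packing sketch:

```rust
fn main() {
    let packed: i32 = 0xFC03_FE01u32 as i32; // output of the packing sketch
    let (pack_size, lshift, rshift) = (4usize, 8usize, 24usize); // Q8_0 constants
    let mut out = [0i32; 4];
    for j in 0..pack_size {
        // Shift lane j into bits 24..32, then sign-extend back down.
        out[j] = (packed << (lshift * (pack_size - j - 1))) >> rshift;
    }
    assert_eq!(out, [1, -2, 3, -4]);
    // Multiplying by the group's block_absmax (omitted here) completes dequantization.
}
```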

@@ -119,6 +157,27 @@ pub fn dequantize(quantized: Tensor) -> Tensor {
             let dequantized = dequantize_inner::<Q8_0F>(&raw_bytes, elements);
             Tensor::from_data(&dequantized, original_shape, Device::CPU)
         }
+        DType::Q4_KF(_) => {
+            let elements = quantized.shape().numel();
+            let original_shape = quantized.shape().clone();
+            let raw_bytes = unsafe { quantized.into_bytes().unwrap() };
+            let dequantized = dequantize_inner::<Q4_KF>(&raw_bytes, elements);
+            Tensor::from_data(&dequantized, original_shape, Device::CPU)
+        }
+        DType::Q8_0H(_) => {
+            let elements = quantized.shape().numel();
+            let original_shape = quantized.shape().clone();
+            let raw_bytes = unsafe { quantized.into_bytes().unwrap() };
+            let dequantized = dequantize_inner::<Q8_0H>(&raw_bytes, elements);
+            Tensor::from_data(&dequantized, original_shape, Device::CPU)
+        }
+        DType::Q4_KH(_) => {
+            let elements = quantized.shape().numel();
+            let original_shape = quantized.shape().clone();
+            let raw_bytes = unsafe { quantized.into_bytes().unwrap() };
+            let dequantized = dequantize_inner::<Q4_KH>(&raw_bytes, elements);
+            Tensor::from_data(&dequantized, original_shape, Device::CPU)
+        }
         dt => panic!("Unsupported dtype {dt}"),
     };
 }
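Both sides pad their buffers so the byte count is a multiple of STORAGE_BUFFER_ALIGN: quantize_inner now goes through the new storage_align helper, while dequantize_inner still carries a local aligner closure with the same arithmetic. The rounding in isolation, as a sketch that assumes a 256-byte alignment (illustrative only; the crate defines the real constant in its gpu module) and uses usize::next_multiple_of in place of the committed if/else:

```rust
// Hypothetical value for the sketch; ratchet-core defines its own constant.
const STORAGE_BUFFER_ALIGN: usize = 256;

// Round an element count up so the backing byte size is aligned.
fn storage_align<T>(n: usize) -> usize {
    let size_t = core::mem::size_of::<T>();
    let nbytes = n * size_t;
    nbytes.next_multiple_of(STORAGE_BUFFER_ALIGN) / size_t
}

fn main() {
    // 1000 u32s = 4000 bytes, padded to 4096 bytes = 1024 elements.
    assert_eq!(storage_align::<u32>(1000), 1024);
    // An already-aligned count passes through: 1024 f32s = 4096 bytes.
    assert_eq!(storage_align::<f32>(1024), 1024);
}
```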
@@ -299,7 +358,10 @@ impl Quantization {
 
 #[cfg(test)]
 mod tests {
-    use crate::{dequantize, quantize, shape, Device, Quantization, Quantizer, Tensor, Q8_0F};
+    use crate::{
+        dequantize, quantize, quantize_inner, shape, Device, Quantization, Quantizer, Tensor,
+        Q4_KF, Q8_0F,
+    };
 
     #[test]
     pub fn test_sint8_qdq() {
@@ -317,7 +379,38 @@ mod tests {
         let q1_raw = unsafe { q1.deep_clone().into_bytes().unwrap() };
         let q2_raw = unsafe { q2.deep_clone().into_bytes().unwrap() };
         assert_eq!(q1_raw, q2_raw);
+        if q1_raw == q2_raw {
+            println!("SInt8 quantization is correct");
+        }
 
         dq1.all_close(&dq2, 1e-3, 1e-3).unwrap();
     }
+
+    #[test]
+    pub fn test_sint4_qdq() {
+        let ground = Tensor::randn::<f32>(shape![64, 64], Device::CPU);
+
+        // Old api
+        let data = ground.to_vec::<f32>().unwrap();
+        let (q1, absmax) = Quantizer::sint4_quantize::<f32>(&data, 64, 64);
+        let dq1 = Quantizer::sint4_dequantize(&q1, absmax, 64, 64);
+
+        // New api
+        let q2 = quantize_inner::<Q4_KF>(&data, 64 * 64);
+        //let dq2 = dequantize(q2.deep_clone());
+
+        for (a, b) in q1.iter().zip(q2.iter()) {
+            if a != b {
+                println!("{} {}", a, b);
+            }
+        }
+        /*
+        let dq2_vec = dq2.to_vec::<f32>().unwrap();
+        for (a, b) in dq1.iter().zip(dq2_vec.iter()) {
+            if (a - b).abs() >= 1e-3 {
+                println!("{} {}", a, b);
+            }
+        }
+        */
+    }
 }
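With the new arms wired up on both sides, a Q4 round trip through the public API would look like the sketch below, mirroring test_sint8_qdq. The test name and tolerances are made up for illustration; a 4-bit round trip is far lossier than the 8-bit one tested above.

```rust
// Sketch only: another test inside ratchet-core's mod tests.
#[test]
pub fn test_q4_roundtrip_sketch() {
    use crate::{dequantize, quantize, shape, Device, Tensor, Q4_KF};
    let ground = Tensor::randn::<f32>(shape![64, 64], Device::CPU);
    let q = quantize::<Q4_KF>(&ground); // hits the (F32, Q4_KF) arm
    let dq = dequantize(q); // hits the DType::Q4_KF arm
    ground.all_close(&dq, 1e-1, 1e-1).unwrap(); // hypothetical tolerance
}
```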
