diff --git a/crates/ratchet-core/src/cpu/gemm.rs b/crates/ratchet-core/src/cpu/gemm.rs index a0d24021..b7286ea0 100644 --- a/crates/ratchet-core/src/cpu/gemm.rs +++ b/crates/ratchet-core/src/cpu/gemm.rs @@ -1,10 +1,11 @@ use crate::{ - cpu_store_result, CPUOperation, Matmul, MatmulSpec, OperationError, Shape, Strides, Tensor, - TensorDType, + cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError, + Shape, Strides, Tensor, TensorDType, }; use anyhow::{anyhow, Result}; use core::str::FromStr; use gemm::{gemm, Parallelism}; +use half::{bf16, f16}; use std::num::NonZeroUsize; fn get_num_threads() -> NonZeroUsize { @@ -74,8 +75,6 @@ fn gemm_impl( let rhs_strides = rhs_strides.to_vec(); let rank = lhs_shape.rank(); - println!("lhs_strides: {lhs_strides:?}, rhs_strides: {rhs_strides:?}, rank: {rank}"); - let lhs_cs = lhs_strides[rank - 1]; let lhs_rs = lhs_strides[rank - 2]; @@ -98,11 +97,6 @@ fn gemm_impl( let mut dst = vec![T::zero(); b * m * n]; - println!("b: {b}, m: {m}, n: {n}, k: {k}"); - println!("dst_cs: {dst_cs}, dst_rs: {dst_rs}, dst_skip: {dst_skip}"); - println!("lhs_cs: {lhs_cs}, lhs_rs: {lhs_rs}, lhs_skip: {lhs_skip}"); - println!("rhs_cs: {rhs_cs}, rhs_rs: {rhs_rs}, rhs_skip: {rhs_skip}"); - for step in 0..b { let lhs_p = &lhs[step * lhs_skip..]; let rhs_p = &rhs[step * rhs_skip..]; @@ -135,28 +129,34 @@ fn gemm_impl( } impl CPUOperation for Matmul { - fn apply(&self, dst_tensor: Tensor) -> Result { + fn apply(&self, dst: Tensor) -> Result { + fn run_gemm( + spec: MatmulSpec, + lhs: &Tensor, + rhs: &Tensor, + dst: &Tensor, + ) -> Result<(), OperationError> { + let lhs = lhs.to_vec::()?; + let rhs = rhs.to_vec::()?; + + let result = if spec.trans_dst() { + gemm_impl::(spec, &rhs, &lhs)? + } else { + gemm_impl::(spec, &lhs, &rhs)? + }; + cpu_store_result(dst, &result); + Ok(()) + } let spec = self.compute_spec(); - let Matmul { - lhs, - rhs, - bias, - trans_lhs, - trans_rhs, - trans_dst, - } = self; - - let lhs = lhs.to_vec::()?; - let rhs = rhs.to_vec::()?; - - let result = if spec.trans_dst() { - gemm_impl::(spec, &rhs, &lhs)? - } else { - gemm_impl::(spec, &lhs, &rhs)? - }; + let Matmul { lhs, rhs, .. } = self; - cpu_store_result(&dst_tensor, &result); - Ok(dst_tensor) + match self.lhs.dt() { + DType::F32 => run_gemm::(spec, lhs, rhs, &dst), + DType::F16 => run_gemm::(spec, lhs, rhs, &dst), + DType::BF16 => run_gemm::(spec, lhs, rhs, &dst), + dtype => Err(InvariantError::UnsupportedDType(dtype))?, + }?; + Ok(dst) } }