feat: f16 and bf16 gemm support
ivarflakstad committed Aug 24, 2024
1 parent 4cda652 commit e9bc238
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions crates/ratchet-core/src/cpu/gemm.rs
@@ -1,10 +1,11 @@
 use crate::{
-    cpu_store_result, CPUOperation, Matmul, MatmulSpec, OperationError, Shape, Strides, Tensor,
-    TensorDType,
+    cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
+    Shape, Strides, Tensor, TensorDType,
 };
 use anyhow::{anyhow, Result};
 use core::str::FromStr;
 use gemm::{gemm, Parallelism};
+use half::{bf16, f16};
 use std::num::NonZeroUsize;
 
 fn get_num_threads() -> NonZeroUsize {
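
For context on the new import: the half crate supplies the two 16-bit float types the kernel becomes generic over. A minimal round-trip sketch, assuming half 2.x (the values here are illustrative, not from the commit):

use half::{bf16, f16};

fn main() {
    // f16 is IEEE 754 binary16 (1 sign / 5 exponent / 10 mantissa bits);
    // bf16 keeps f32's exponent range (1/8/7) at the cost of precision.
    let a = f16::from_f32(1.5);
    let b = bf16::from_f32(1.5);
    assert_eq!(a.to_f32(), 1.5);
    assert_eq!(b.to_f32(), 1.5);
    // half 2.x implements arithmetic directly on the 16-bit types,
    // which is what lets a generic kernel multiply-accumulate in them.
    assert_eq!((a * f16::from_f32(2.0)).to_f32(), 3.0);
}
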
@@ -74,8 +75,6 @@ fn gemm_impl<T: TensorDType>(
     let rhs_strides = rhs_strides.to_vec();
     let rank = lhs_shape.rank();
 
-    println!("lhs_strides: {lhs_strides:?}, rhs_strides: {rhs_strides:?}, rank: {rank}");
-
     let lhs_cs = lhs_strides[rank - 1];
     let lhs_rs = lhs_strides[rank - 2];

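A worked example of the stride math above: for a contiguous row-major tensor of shape (b, m, k), the strides are [m * k, k, 1], so the column stride (cs) is 1, the row stride (rs) is k, and one batch step skips m * k elements. A small sketch with a hypothetical contiguous_strides helper (ratchet's Strides type presumably plays this role for real tensors):

// Hypothetical helper: row-major strides for a given shape.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    // The last axis is contiguous; each earlier stride is the
    // product of all later dimensions.
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

fn main() {
    let (b, m, k) = (2usize, 3, 4);
    let strides = contiguous_strides(&[b, m, k]);
    assert_eq!(strides, vec![m * k, k, 1]);
    let rank = strides.len();
    let cs = strides[rank - 1]; // 1: step between adjacent columns
    let rs = strides[rank - 2]; // k: step between adjacent rows
    assert_eq!((cs, rs), (1, k));
}
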
@@ -98,11 +97,6 @@ fn gemm_impl<T: TensorDType>(
 
     let mut dst = vec![T::zero(); b * m * n];
 
-    println!("b: {b}, m: {m}, n: {n}, k: {k}");
-    println!("dst_cs: {dst_cs}, dst_rs: {dst_rs}, dst_skip: {dst_skip}");
-    println!("lhs_cs: {lhs_cs}, lhs_rs: {lhs_rs}, lhs_skip: {lhs_skip}");
-    println!("rhs_cs: {rhs_cs}, rhs_rs: {rhs_rs}, rhs_skip: {rhs_skip}");
-
     for step in 0..b {
         let lhs_p = &lhs[step * lhs_skip..];
         let rhs_p = &rhs[step * rhs_skip..];
@@ -135,28 +129,34 @@ fn gemm_impl<T: TensorDType>(
 }
 
 impl CPUOperation for Matmul {
-    fn apply(&self, dst_tensor: Tensor) -> Result<Tensor, OperationError> {
+    fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        fn run_gemm<T: TensorDType>(
+            spec: MatmulSpec,
+            lhs: &Tensor,
+            rhs: &Tensor,
+            dst: &Tensor,
+        ) -> Result<(), OperationError> {
+            let lhs = lhs.to_vec::<T>()?;
+            let rhs = rhs.to_vec::<T>()?;
+
+            let result = if spec.trans_dst() {
+                gemm_impl::<T>(spec, &rhs, &lhs)?
+            } else {
+                gemm_impl::<T>(spec, &lhs, &rhs)?
+            };
+            cpu_store_result(dst, &result);
+            Ok(())
+        }
         let spec = self.compute_spec();
 
-        let Matmul {
-            lhs,
-            rhs,
-            bias,
-            trans_lhs,
-            trans_rhs,
-            trans_dst,
-        } = self;
-
-        let lhs = lhs.to_vec::<f32>()?;
-        let rhs = rhs.to_vec::<f32>()?;
-
-        let result = if spec.trans_dst() {
-            gemm_impl::<f32>(spec, &rhs, &lhs)?
-        } else {
-            gemm_impl::<f32>(spec, &lhs, &rhs)?
-        };
+        let Matmul { lhs, rhs, .. } = self;
 
-        cpu_store_result(&dst_tensor, &result);
-        Ok(dst_tensor)
+        match self.lhs.dt() {
+            DType::F32 => run_gemm::<f32>(spec, lhs, rhs, &dst),
+            DType::F16 => run_gemm::<f16>(spec, lhs, rhs, &dst),
+            DType::BF16 => run_gemm::<bf16>(spec, lhs, rhs, &dst),
+            dtype => Err(InvariantError::UnsupportedDType(dtype))?,
+        }?;
+        Ok(dst)
     }
 }
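
The shape of the change reads clearest outside the diff: one kernel generic over the element type, monomorphized per dtype, with a runtime match on the input's dtype selecting the instantiation. Below is a self-contained sketch of that pattern, assuming only the half crate; Elem, naive_matmul, and the local DType enum are hypothetical stand-ins for ratchet's TensorDType, gemm_impl, and DType, and a naive triple loop stands in for the gemm crate call:

use half::{bf16, f16};
use std::ops::{Add, Mul};

// Minimal element trait standing in for ratchet's TensorDType.
trait Elem: Copy + Add<Output = Self> + Mul<Output = Self> {
    const ZERO: Self;
}
impl Elem for f32 {
    const ZERO: Self = 0.0;
}
impl Elem for f16 {
    const ZERO: Self = f16::ZERO;
}
impl Elem for bf16 {
    const ZERO: Self = bf16::ZERO;
}

// Naive row-major (m, k) x (k, n) matmul, monomorphized per dtype.
fn naive_matmul<T: Elem>(lhs: &[T], rhs: &[T], m: usize, k: usize, n: usize) -> Vec<T> {
    let mut dst = vec![T::ZERO; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = T::ZERO;
            for p in 0..k {
                acc = acc + lhs[i * k + p] * rhs[p * n + j];
            }
            dst[i * n + j] = acc;
        }
    }
    dst
}

// Local stand-in for ratchet's DType.
#[allow(dead_code)]
enum DType {
    F32,
    F16,
    BF16,
}

fn main() {
    // Runtime dispatch mirroring the match in apply().
    let dt = DType::F16;
    let (m, k, n) = (2, 2, 2);
    match dt {
        DType::F32 => {
            let c = naive_matmul::<f32>(&[1.0, 2.0, 3.0, 4.0], &[1.0, 0.0, 0.0, 1.0], m, k, n);
            println!("{c:?}");
        }
        DType::F16 => {
            let a = [1.0f32, 2.0, 3.0, 4.0].map(f16::from_f32);
            let b = [1.0f32, 0.0, 0.0, 1.0].map(f16::from_f32);
            let c = naive_matmul(&a, &b, m, k, n);
            println!("{:?}", c.iter().map(|x| x.to_f32()).collect::<Vec<_>>());
        }
        DType::BF16 => {
            let a = [1.0f32, 2.0, 3.0, 4.0].map(bf16::from_f32);
            let b = [1.0f32, 0.0, 0.0, 1.0].map(bf16::from_f32);
            let c = naive_matmul(&a, &b, m, k, n);
            println!("{:?}", c.iter().map(|x| x.to_f32()).collect::<Vec<_>>());
        }
    }
}

The payoff of this structure is that adding a dtype costs one match arm plus one trait impl, while the compiler emits a separately specialized kernel for each instantiation.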
