fix: gemm works on both cpu and gpu
ivarflakstad committed Aug 23, 2024
1 parent f4e4173 commit 35eca06
Showing 4 changed files with 30 additions and 53 deletions.
17 changes: 3 additions & 14 deletions crates/ratchet-core/src/ops/matmul/gemm.rs
@@ -181,27 +181,16 @@ impl Kernel for GEMM {
let dim_rhs_outer = spec.dim_rhs_outer() as i32;
let dim_inner = spec.dim_inner() as i32;

println!("GEMMMeta");
println!("lhs_shape: {:?}", lhs_shape);
println!("lhs_strides: {:?}", lhs_strides);
println!("rhs_shape: {:?}", rhs_shape);
println!("rhs_strides: {:?}", rhs_strides);
println!("dst_shape: {:?}", dst_shape);
println!("dst_strides: {:?}", dst_strides);
println!("dim_lhs_outer: {:?}", spec.m());
println!("dim_rhs_outer: {:?}", spec.n());
println!("dim_inner: {:?}", spec.k());

Ok(GEMMMeta {
lhs_shape: lhs_shape.into(),
lhs_strides: lhs_strides.into(),
rhs_shape: rhs_shape.into(),
rhs_strides: rhs_strides.into(),
dst_shape: dst_shape.into(),
dst_strides: dst_strides.into(),
dim_lhs_outer: spec.m() as i32,
dim_rhs_outer: spec.n() as i32,
dim_inner: spec.k() as i32,
dim_lhs_outer,
dim_rhs_outer,
dim_inner,
})
}
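The three scalar fields in GEMMMeta are the classic GEMM dimensions: the removed initializers show that dim_lhs_outer maps to spec.m(), dim_rhs_outer to spec.n(), and dim_inner to spec.k(). The sketch below is a hypothetical strided reference kernel, not ratchet's implementation, showing how those fields and the stride arrays would drive the loop bounds and indexing.

// Hypothetical reference kernel (not ratchet's): lhs is M x K, rhs is K x N,
// dst is M x N, each addressed through explicit row/column strides.
fn reference_gemm(
    lhs: &[f32], lhs_strides: [usize; 2],
    rhs: &[f32], rhs_strides: [usize; 2],
    dst: &mut [f32], dst_strides: [usize; 2],
    dim_lhs_outer: usize, // M
    dim_rhs_outer: usize, // N
    dim_inner: usize,     // K
) {
    for m in 0..dim_lhs_outer {
        for n in 0..dim_rhs_outer {
            let mut acc = 0.0f32;
            for k in 0..dim_inner {
                acc += lhs[m * lhs_strides[0] + k * lhs_strides[1]]
                     * rhs[k * rhs_strides[0] + n * rhs_strides[1]];
            }
            dst[m * dst_strides[0] + n * dst_strides[1]] = acc;
        }
    }
}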

46 changes: 25 additions & 21 deletions crates/ratchet-core/src/ops/matmul/mod.rs
@@ -12,9 +12,9 @@ use std::{cmp::Ordering, mem};

use crate::{
gpu::{BindGroupLayoutDescriptor, CpuUniform},
rvec, DType, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata, KernelRenderable,
KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView, Strides, Tensor,
WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
rvec, DType, Device, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata,
KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView,
Strides, Tensor, WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
};

//https://link.springer.com/chapter/10.1007/978-3-642-29737-3_42
@@ -132,24 +132,28 @@ impl MatmulSpec {
let mut rhs_strides = Strides::from(&rhs_shape);
let dst_strides = Strides::from(&dst_shape);

// The (a b)T => bT aT rule means that if we have to transpose dst we can simply transpose the inputs and swap them.
// However two transposes cancel each other out, in which case we can just skip transposing the input altogether.
// This is just the xor operator (^).
if trans_lhs ^ trans_dst {
lhs_shape.transpose();
lhs_strides.transpose();
}
if trans_rhs ^ trans_dst {
rhs_shape.transpose();
rhs_strides.transpose();
}
if trans_dst {
// (a b)T => bT aT
// aT bT has already been applied correctly above, so we can just swap.
mem::swap(&mut lhs_shape, &mut rhs_shape);
// strides and transposes must follow their shapes
mem::swap(&mut lhs_strides, &mut rhs_strides);
mem::swap(&mut trans_lhs, &mut trans_rhs);
let is_cpu = matches!(LHS.device(), Device::CPU);

if is_cpu {
// The (a b)T => bT aT rule means that if we have to transpose dst we can simply transpose the inputs and swap them.
// However two transposes cancel each other out, in which case we can just skip transposing the input altogether.
// This is just the xor operator (^).
if trans_lhs ^ trans_dst {
lhs_shape.transpose();
lhs_strides.transpose();
}
if trans_rhs ^ trans_dst {
rhs_shape.transpose();
rhs_strides.transpose();
}
if trans_dst {
// (a b)T => bT aT
// aT bT has already been applied correctly above, so we can just swap.
mem::swap(&mut lhs_shape, &mut rhs_shape);
// strides and transposes must follow their shapes
mem::swap(&mut lhs_strides, &mut rhs_strides);
mem::swap(&mut trans_lhs, &mut trans_rhs);
}
}

log::debug!(
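This hunk is the core of the fix: the (a b)T => bT aT rewrite, with the xor cancellation for double transposes, is now applied only when the operands live on the CPU; on the GPU path the shapes and strides are passed through untouched. Below is a minimal shape-only sketch of that rewrite, assuming 2-D operands and using illustrative types rather than ratchet's Shape/Strides.

// Shape-only sketch of the CPU-path rewrite above; names and types are
// illustrative, not ratchet's.
fn rewrite_shapes(
    mut lhs: [usize; 2],
    mut rhs: [usize; 2],
    trans_lhs: bool,
    trans_rhs: bool,
    trans_dst: bool,
) -> ([usize; 2], [usize; 2]) {
    // An input is flipped iff exactly one of "transpose this input" and
    // "transpose the destination" is requested: two transposes cancel (xor).
    if trans_lhs ^ trans_dst {
        lhs.swap(0, 1);
    }
    if trans_rhs ^ trans_dst {
        rhs.swap(0, 1);
    }
    // (a b)T = bT aT: instead of transposing dst, swap the adjusted operands.
    if trans_dst {
        std::mem::swap(&mut lhs, &mut rhs);
    }
    (lhs, rhs)
}

For example, a 2x3 lhs and 3x4 rhs with only trans_dst set come out as a 4x3 lhs and 3x2 rhs, whose product is the 4x2 transposed destination.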
7 changes: 1 addition & 6 deletions crates/ratchet-core/src/ops/matmul/subgroup_gemv.rs
@@ -353,14 +353,9 @@ impl Kernel for SubgroupGEMV {
}

fn metadata(&self, _: &Tensor, _: &KernelElement) -> Result<Self::Metadata, OperationError> {
println!(
"SubgroupGEMVMeta: OVL: {}, IVL: {}",
self.spec.new_dim_lhs_outer(),
self.spec.k()
);
Ok(SubgroupGEMVMeta {
OVL: self.spec.new_dim_lhs_outer() as _,
IVL: self.spec.k() as _,
IVL: self.spec.new_dim_rhs_outer() as _,
})
}

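IVL now comes from spec.new_dim_rhs_outer() rather than spec.k(). The diff does not spell out the acronyms, but if OVL and IVL are read as the output and inner vector lengths of a matrix-vector product (an assumption), they bound the two loops of a plain GEMV as in this hypothetical sketch.

// Hypothetical naive GEMV: `a` is a row-major ovl x ivl matrix, `x` has ivl
// elements, `y` has ovl elements. The field-name reading is an assumption.
fn naive_gemv(a: &[f32], x: &[f32], y: &mut [f32], ovl: usize, ivl: usize) {
    for m in 0..ovl {
        let mut acc = 0.0f32;
        for k in 0..ivl {
            acc += a[m * ivl + k] * x[k];
        }
        y[m] = acc;
    }
}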
13 changes: 1 addition & 12 deletions crates/ratchet-core/src/ops/matmul/workgroup_gemv.rs
Expand Up @@ -101,7 +101,7 @@ impl Kernel for WorkgroupGEMV {

fn metadata(&self, _: &Tensor, _: &KernelElement) -> Result<Self::Metadata, OperationError> {
let spec = &self.spec;
let mut lhs_shape = spec.raw_lhs_shape().clone();
let mut lhs_shape = spec.lhs_shape().clone();
lhs_shape.insert(0, spec.lhs_stack());
let lhs_strides = Strides::from(&lhs_shape);

@@ -117,17 +117,6 @@
let dim_rhs_outer = spec.dim_rhs_outer() as i32;
let dim_inner = spec.dim_inner() as i32;

println!("WorkgroupGEMVMeta");
println!("lhs_shape: {:?}", lhs_shape);
println!("lhs_strides: {:?}", lhs_strides);
println!("rhs_shape: {:?}", rhs_shape);
println!("rhs_strides: {:?}", rhs_strides);
println!("dst_shape: {:?}", dst_shape);
println!("dst_strides: {:?}", dst_strides);
println!("dim_lhs_outer: {:?}", spec.m());
println!("dim_rhs_outer: {:?}", spec.n());
println!("dim_inner: {:?}", spec.k());

Ok(WorkgroupGEMVMeta {
lhs_shape: lhs_shape.into(),
lhs_strides: lhs_strides.into(),
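Given the commit message, the natural acceptance check is that the CPU and GPU backends now produce matching GEMM output for the same inputs. The helper below is a hypothetical, backend-agnostic tolerance comparison; it does not use ratchet's API, and cpu_out / gpu_out stand for the two results read back as f32 slices.

// Hypothetical cross-device comparison helper; names are illustrative only.
fn all_close(cpu_out: &[f32], gpu_out: &[f32], atol: f32, rtol: f32) -> bool {
    cpu_out.len() == gpu_out.len()
        && cpu_out
            .iter()
            .zip(gpu_out)
            .all(|(a, b)| (a - b).abs() <= atol + rtol * b.abs())
}

Something like all_close(&cpu_out, &gpu_out, 1e-5, 1e-4) is a typical f32 tolerance, though suitable thresholds depend on problem size.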
