From 35eca06d6924854a89bbb72a12702975d9d6fdd9 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Fri, 23 Aug 2024 10:52:32 +0200
Subject: [PATCH] fix: gemm works on both cpu and gpu

---
 crates/ratchet-core/src/ops/matmul/gemm.rs    | 17 ++-----
 crates/ratchet-core/src/ops/matmul/mod.rs     | 46 ++++++++++---------
 .../src/ops/matmul/subgroup_gemv.rs           |  7 +--
 .../src/ops/matmul/workgroup_gemv.rs          | 13 +-----
 4 files changed, 30 insertions(+), 53 deletions(-)

diff --git a/crates/ratchet-core/src/ops/matmul/gemm.rs b/crates/ratchet-core/src/ops/matmul/gemm.rs
index cbc2d2ec..b94fc980 100644
--- a/crates/ratchet-core/src/ops/matmul/gemm.rs
+++ b/crates/ratchet-core/src/ops/matmul/gemm.rs
@@ -181,17 +181,6 @@ impl Kernel for GEMM {
         let dim_rhs_outer = spec.dim_rhs_outer() as i32;
         let dim_inner = spec.dim_inner() as i32;
 
-        println!("GEMMMeta");
-        println!("lhs_shape: {:?}", lhs_shape);
-        println!("lhs_strides: {:?}", lhs_strides);
-        println!("rhs_shape: {:?}", rhs_shape);
-        println!("rhs_strides: {:?}", rhs_strides);
-        println!("dst_shape: {:?}", dst_shape);
-        println!("dst_strides: {:?}", dst_strides);
-        println!("dim_lhs_outer: {:?}", spec.m());
-        println!("dim_rhs_outer: {:?}", spec.n());
-        println!("dim_inner: {:?}", spec.k());
-
         Ok(GEMMMeta {
             lhs_shape: lhs_shape.into(),
             lhs_strides: lhs_strides.into(),
@@ -199,9 +188,9 @@ impl Kernel for GEMM {
             rhs_strides: rhs_strides.into(),
             dst_shape: dst_shape.into(),
             dst_strides: dst_strides.into(),
-            dim_lhs_outer: spec.m() as i32,
-            dim_rhs_outer: spec.n() as i32,
-            dim_inner: spec.k() as i32,
+            dim_lhs_outer,
+            dim_rhs_outer,
+            dim_inner,
         })
     }
 
diff --git a/crates/ratchet-core/src/ops/matmul/mod.rs b/crates/ratchet-core/src/ops/matmul/mod.rs
index b822cba6..d46be97c 100644
--- a/crates/ratchet-core/src/ops/matmul/mod.rs
+++ b/crates/ratchet-core/src/ops/matmul/mod.rs
@@ -12,9 +12,9 @@ use std::{cmp::Ordering, mem};
 
 use crate::{
     gpu::{BindGroupLayoutDescriptor, CpuUniform},
-    rvec, DType, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata, KernelRenderable,
-    KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView, Strides, Tensor,
-    WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
+    rvec, DType, Device, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata,
+    KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView,
+    Strides, Tensor, WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
 };
 
 //https://link.springer.com/chapter/10.1007/978-3-642-29737-3_42
@@ -132,24 +132,28 @@ impl MatmulSpec {
         let mut rhs_strides = Strides::from(&rhs_shape);
         let dst_strides = Strides::from(&dst_shape);
 
-        // The (a b)T => bT aT rule means that if we have to transpose dst we can simply transpose the inputs and swap them.
-        // However two transposes cancel each other out, in which case we can just skip transposing the input altogether.
-        // This is just the xor operator (^).
-        if trans_lhs ^ trans_dst {
-            lhs_shape.transpose();
-            lhs_strides.transpose();
-        }
-        if trans_rhs ^ trans_dst {
-            rhs_shape.transpose();
-            rhs_strides.transpose();
-        }
-        if trans_dst {
-            // (a b)T => bT aT
-            // aT bT has already been applied correctly above, so we can just swap.
-            mem::swap(&mut lhs_shape, &mut rhs_shape);
-            // strides and transposes must follow their shapes
-            mem::swap(&mut lhs_strides, &mut rhs_strides);
-            mem::swap(&mut trans_lhs, &mut trans_rhs);
+        let is_cpu = matches!(LHS.device(), Device::CPU);
+
+        if is_cpu {
+            // The (a b)T => bT aT rule means that if we have to transpose dst we can simply transpose the inputs and swap them.
+            // However two transposes cancel each other out, in which case we can just skip transposing the input altogether.
+            // This is just the xor operator (^).
+            if trans_lhs ^ trans_dst {
+                lhs_shape.transpose();
+                lhs_strides.transpose();
+            }
+            if trans_rhs ^ trans_dst {
+                rhs_shape.transpose();
+                rhs_strides.transpose();
+            }
+            if trans_dst {
+                // (a b)T => bT aT
+                // aT bT has already been applied correctly above, so we can just swap.
+                mem::swap(&mut lhs_shape, &mut rhs_shape);
+                // strides and transposes must follow their shapes
+                mem::swap(&mut lhs_strides, &mut rhs_strides);
+                mem::swap(&mut trans_lhs, &mut trans_rhs);
+            }
         }
 
         log::debug!(
diff --git a/crates/ratchet-core/src/ops/matmul/subgroup_gemv.rs b/crates/ratchet-core/src/ops/matmul/subgroup_gemv.rs
index 9609c9e9..0612c0da 100644
--- a/crates/ratchet-core/src/ops/matmul/subgroup_gemv.rs
+++ b/crates/ratchet-core/src/ops/matmul/subgroup_gemv.rs
@@ -353,14 +353,9 @@ impl Kernel for SubgroupGEMV {
     }
 
     fn metadata(&self, _: &Tensor, _: &KernelElement) -> Result<Self::Metadata, OperationError> {
-        println!(
-            "SubgroupGEMVMeta: OVL: {}, IVL: {}",
-            self.spec.new_dim_lhs_outer(),
-            self.spec.k()
-        );
         Ok(SubgroupGEMVMeta {
             OVL: self.spec.new_dim_lhs_outer() as _,
-            IVL: self.spec.k() as _,
+            IVL: self.spec.new_dim_rhs_outer() as _,
         })
     }
 
diff --git a/crates/ratchet-core/src/ops/matmul/workgroup_gemv.rs b/crates/ratchet-core/src/ops/matmul/workgroup_gemv.rs
index d5898d44..2fc37e6a 100644
--- a/crates/ratchet-core/src/ops/matmul/workgroup_gemv.rs
+++ b/crates/ratchet-core/src/ops/matmul/workgroup_gemv.rs
@@ -101,7 +101,7 @@ impl Kernel for WorkgroupGEMV {
     fn metadata(&self, _: &Tensor, _: &KernelElement) -> Result<Self::Metadata, OperationError> {
         let spec = &self.spec;
 
-        let mut lhs_shape = spec.raw_lhs_shape().clone();
+        let mut lhs_shape = spec.lhs_shape().clone();
         lhs_shape.insert(0, spec.lhs_stack());
         let lhs_strides = Strides::from(&lhs_shape);
 
@@ -117,17 +117,6 @@ impl Kernel for WorkgroupGEMV {
         let dim_rhs_outer = spec.dim_rhs_outer() as i32;
         let dim_inner = spec.dim_inner() as i32;
 
-        println!("WorkgroupGEMVMeta");
-        println!("lhs_shape: {:?}", lhs_shape);
-        println!("lhs_strides: {:?}", lhs_strides);
-        println!("rhs_shape: {:?}", rhs_shape);
-        println!("rhs_strides: {:?}", rhs_strides);
-        println!("dst_shape: {:?}", dst_shape);
-        println!("dst_strides: {:?}", dst_strides);
-        println!("dim_lhs_outer: {:?}", spec.m());
-        println!("dim_rhs_outer: {:?}", spec.n());
-        println!("dim_inner: {:?}", spec.k());
-
         Ok(WorkgroupGEMVMeta {
             lhs_shape: lhs_shape.into(),
             lhs_strides: lhs_strides.into(),
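
A note on the MatmulSpec hunk above: the CPU-only branch folds a transposed
dst into the inputs via the (a b)^T = b^T a^T rule, and since two transposes
cancel, each input's effective transpose flag is its own flag XORed with
trans_dst. The standalone sketch below checks the identity the final swap
relies on; matmul and transpose here are throwaway helpers over nested Vecs
written for this note, not ratchet's Shape/Strides API.

fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Vec<Vec<f32>> {
    // Row-major (m x k) . (k x n) -> (m x n).
    let (m, k, n) = (a.len(), b.len(), b[0].len());
    let mut c = vec![vec![0.0f32; n]; m];
    for i in 0..m {
        for p in 0..k {
            for j in 0..n {
                c[i][j] += a[i][p] * b[p][j];
            }
        }
    }
    c
}

fn transpose(a: &[Vec<f32>]) -> Vec<Vec<f32>> {
    (0..a[0].len())
        .map(|j| a.iter().map(|row| row[j]).collect())
        .collect()
}

fn main() {
    let lhs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]; // 3 x 2
    let rhs = vec![vec![7.0, 8.0, 9.0], vec![10.0, 11.0, 12.0]]; // 2 x 3

    // Naive: materialize dst = lhs . rhs, then transpose it.
    let naive = transpose(&matmul(&lhs, &rhs));
    // Folded, as in the patch: transpose both inputs and swap them,
    // so dst itself never needs a transpose pass.
    let folded = matmul(&transpose(&rhs), &transpose(&lhs));

    assert_eq!(naive, folded);
}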
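
The ^ in trans_lhs ^ trans_dst is parity arithmetic on transpose counts: a
transpose already requested for an input and one inherited from dst cancel
out. A minimal check of that folding rule:

fn main() {
    for trans_lhs in [false, true] {
        for trans_dst in [false, true] {
            // An input is transposed in the kernel iff an odd number of
            // transposes is pending on it, and odd-count-over-two-bits is XOR.
            let effective = trans_lhs ^ trans_dst;
            let parity = (trans_lhs as u8 + trans_dst as u8) % 2 == 1;
            assert_eq!(effective, parity);
        }
    }
}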
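
On the SubgroupGEMVMeta fix: judging by the debug print this patch removes,
OVL and IVL read as the output and input vector lengths of a GEMV. Assuming
that reading (the patch itself does not spell it out), a plain GEMV shows the
two lengths come from different matrix dimensions, which is presumably why the
patch derives IVL from the normalized spec (new_dim_rhs_outer()) rather than
the raw k(). gemv below is a local illustration, not ratchet code.

fn gemv(a: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    // y has one element per row of a (the "output vector length");
    // x must have one element per column (the "input vector length").
    a.iter()
        .map(|row| row.iter().zip(x).map(|(w, xi)| w * xi).sum())
        .collect()
}

fn main() {
    let a = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; // 2 x 3
    let x = vec![1.0, 1.0, 1.0]; // length 3 == columns
    assert_eq!(gemv(&a, &x), vec![6.0, 15.0]); // length 2 == rows
}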