Feature/cpu gemm #247

Merged: 18 commits, Aug 24, 2024

21 changes: 12 additions & 9 deletions crates/ratchet-core/Cargo.toml
@@ -10,10 +10,10 @@ rand = ["dep:rand", "dep:rand_distr"]
plotting = ["dep:dot3", "dep:tempfile"]
testing = ["dep:npyz", "dep:ndarray"]
pyo3 = ["dep:pyo3", "dep:numpy", "dep:regex"]
debug = [] #dump every node

[dependencies]
ratchet-macros = { path = "../ratchet-macros"}
ratchet-macros = { path = "../ratchet-macros" }
inline-wgsl = { git = "https://github.com/FL33TW00D/inline-wgsl.git", branch = "master" }
wgpu = { workspace = true }
bytemuck = { workspace = true }
@@ -23,15 +23,17 @@ num-traits = { workspace = true }
log = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true, features = ["derive"] }
anyhow.workspace = true

rustc-hash = { workspace = true }
slotmap = { workspace = true }
parking_lot = { workspace = true }
smallvec = { workspace = true }
encase = { workspace = true, features = ["smallvec", "glam"] }
pollster = { workspace = true }
getrandom = { workspace = true, features = ["js"] } # Needed for wasm support in `num` trait
getrandom = { workspace = true, features = [
"js",
] } # Needed for wasm support in `num` trait
num = { workspace = true }
rand_distr = { workspace = true, optional = true }
rand = { workspace = true, optional = true }
@@ -50,17 +52,18 @@ tempfile = { workspace = true, optional = true }
tabled = { workspace = true, optional = true }
itertools = { workspace = true, optional = true }

pyo3 = { workspace = true, features = ["auto-initialize"], optional = true }
regex = { workspace = true, optional = true }
numpy = { workspace = true, optional = true, features=["half"]}
numpy = { workspace = true, optional = true, features = ["half"] }
gemm = { version = "0.18.0", features = ["nightly", "wasm-simd128-enable"] }

[target.'cfg(target_arch = "wasm32")'.dependencies]
wasm-bindgen.workspace = true
futures-intrusive.workspace = true
wasm-bindgen-futures.workspace = true

async-trait = "0.1.77"
smallvec = { workspace = true , features = ["serde"] }
smallvec = { workspace = true, features = ["serde"] }

[dev-dependencies]
env_logger = { workspace = true }
162 changes: 162 additions & 0 deletions crates/ratchet-core/src/cpu/gemm.rs
@@ -0,0 +1,162 @@
use crate::{
cpu_store_result, CPUOperation, DType, InvariantError, Matmul, MatmulSpec, OperationError,
Shape, Strides, Tensor, TensorDType,
};
use anyhow::{anyhow, Result};
use core::str::FromStr;
use gemm::{gemm, Parallelism};
use half::{bf16, f16};
use std::num::NonZeroUsize;

fn get_num_threads() -> NonZeroUsize {
// Respond to the same environment variable as rayon.
match std::env::var("RAYON_NUM_THREADS")
.ok()
.and_then(|s| usize::from_str(&s).ok())
{
Some(x) if x > 0 => NonZeroUsize::new(x).unwrap(),
Some(_) | None => std::thread::available_parallelism()
.unwrap_or_else(|_| NonZeroUsize::new(1usize).unwrap()),
}
}

fn get_parallelism() -> Parallelism {
match get_num_threads().get() {
1 => Parallelism::None,
n => Parallelism::Rayon(n),
}
}

fn calculate_skips(
lhs_shape: &Shape,
lhs_strides: &[isize],
rhs_shape: &Shape,
rhs_strides: &[isize],
rank: usize,
m: usize,
n: usize,
k: usize,
) -> Result<(usize, usize)> {
let lhs_skip: usize = match lhs_strides[..rank - 2] {
[s1, stride] if s1 == stride * lhs_shape[1] as isize => stride as usize,
[_, stride] if lhs_shape[0] == 1 => stride as usize,
[stride, _] if lhs_shape[1] == 1 => stride as usize,
[stride] => stride as usize,
[] => m * k,
_ => Err(anyhow!("non-contiguous lhs"))?,
};
let rhs_skip: usize = match rhs_strides[..rank - 2] {
[s1, stride] if s1 == stride * rhs_shape[1] as isize => stride as usize,
[_, stride] if rhs_shape[0] == 1 => stride as usize,
[stride, _] if rhs_shape[1] == 1 => stride as usize,
[stride] => stride as usize,
[] => n * k,
_ => Err(anyhow!("non-contiguous rhs"))?,
};
Ok((lhs_skip, rhs_skip))
}

fn gemm_impl<T: TensorDType>(
spec: MatmulSpec,
lhs: &[T],
rhs: &[T],
) -> Result<Vec<T>, OperationError> {
let lhs_shape = spec.lhs_shape();
let rhs_shape = spec.rhs_shape();
let lhs_strides = spec.lhs_strides();
let rhs_strides = spec.rhs_strides();
let dst_strides = spec.dst_strides();
let b = spec.stacks();
let m = spec.m();
let n = spec.n();
let k = spec.k();

let lhs_strides = lhs_strides.to_vec();
let rhs_strides = rhs_strides.to_vec();
let rank = lhs_shape.rank();

let lhs_cs = lhs_strides[rank - 1];
let lhs_rs = lhs_strides[rank - 2];

let rhs_cs = rhs_strides[rank - 1];
let rhs_rs = rhs_strides[rank - 2];

let (lhs_skip, rhs_skip) = calculate_skips(
lhs_shape,
&lhs_strides,
rhs_shape,
&rhs_strides,
rank,
m,
n,
k,
)?;
let dst_skip: usize = m * n;
let dst_rs = dst_strides[0];
let dst_cs = dst_strides[1];

let mut dst = vec![T::zero(); b * m * n];

for step in 0..b {
let lhs_p = &lhs[step * lhs_skip..];
let rhs_p = &rhs[step * rhs_skip..];
let dst_p = &mut dst[step * dst_skip..];
unsafe {
gemm(
m,
n,
k,
dst_p.as_mut_ptr(),
dst_cs,
dst_rs,
false,
lhs_p.as_ptr(),
lhs_cs,
lhs_rs,
rhs_p.as_ptr(),
rhs_cs,
rhs_rs,
T::zero(),
T::one(),
false,
false,
false,
get_parallelism(),
)
}
}
Ok(dst)
}

impl CPUOperation for Matmul {
fn apply(&self, dst: Tensor) -> Result<Tensor, OperationError> {
fn run_gemm<T: TensorDType>(
spec: MatmulSpec,
lhs: &Tensor,
rhs: &Tensor,
dst: &Tensor,
) -> Result<(), OperationError> {
let lhs = lhs.to_vec::<T>()?;
let rhs = rhs.to_vec::<T>()?;

let result = if spec.trans_dst() {
gemm_impl::<T>(spec, &rhs, &lhs)?
} else {
gemm_impl::<T>(spec, &lhs, &rhs)?
};
cpu_store_result(dst, &result);
Ok(())
}
let spec = self.compute_spec();

let Matmul { lhs, rhs, .. } = self;

match self.lhs.dt() {
DType::F32 => run_gemm::<f32>(spec, lhs, rhs, &dst),
DType::F16 => run_gemm::<f16>(spec, lhs, rhs, &dst),
DType::BF16 => run_gemm::<bf16>(spec, lhs, rhs, &dst),
dtype => Err(InvariantError::UnsupportedDType(dtype))?,
}?;
Ok(dst)
}
}
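
For orientation, here is a minimal, self-contained sketch of the same gemm::gemm call that gemm_impl above issues, on hypothetical 2x3 and 3x2 row-major f32 inputs. The argument order and the alpha = 0, beta = 1, read_dst = false convention mirror the call in this PR; the concrete matrices and strides are illustrative only, not part of the crate.

use gemm::{gemm, Parallelism};

fn main() {
    // C (2x2) = A (2x3) * B (3x2); all matrices are row-major f32.
    let a: [f32; 6] = [1., 2., 3., 4., 5., 6.];
    let b: [f32; 6] = [7., 8., 9., 10., 11., 12.];
    let mut c = [0.0f32; 4];
    let (m, n, k) = (2usize, 2usize, 3usize);
    unsafe {
        gemm(
            m,
            n,
            k,
            c.as_mut_ptr(),
            1,                   // dst_cs: column stride of C (row-major => 1)
            n as isize,          // dst_rs: row stride of C
            false,               // read_dst: ignore the existing contents of C
            a.as_ptr(),
            1,                   // lhs_cs
            k as isize,          // lhs_rs
            b.as_ptr(),
            1,                   // rhs_cs
            n as isize,          // rhs_rs
            0.0,                 // alpha (matches T::zero() above)
            1.0,                 // beta  (matches T::one() above)
            false, false, false, // conj_dst / conj_lhs / conj_rhs
            Parallelism::None,
        );
    }
    assert_eq!(c, [58.0, 64.0, 139.0, 154.0]);
}

The batched path in gemm_impl issues this same call once per stack, advancing the lhs, rhs, and dst buffers by lhs_skip, rhs_skip, and dst_skip elements per step.
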
8 changes: 5 additions & 3 deletions crates/ratchet-core/src/cpu/mod.rs
@@ -1,3 +1,5 @@
pub mod gemm;

use crate::{
Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect, InvariantError, OpGuards,
Operation, OperationError, RVec, Storage, StorageView, Tensor, TensorDType, Unary, UnaryOp,
@@ -205,8 +207,8 @@ fn index_select<T: TensorDType>(
for left_i in 0..left_len {
let start_src_idx = left_i * right_len * src_dim;
let start_dst_idx = left_i * right_len * n_ids;
for i in 0..n_ids {
let src_idx = start_src_idx + indices[i] as usize * right_len;
for (i, idx) in indices.iter().enumerate().take(n_ids) {
let src_idx = start_src_idx + *idx as usize * right_len;
let dst_idx = start_dst_idx + i * right_len;
result[dst_idx..dst_idx + right_len]
.copy_from_slice(&src[src_idx..src_idx + right_len]);
@@ -232,7 +234,7 @@ fn direct_cast<T: TensorDType, U: TensorDType>(
let input = input.to_vec::<T>()?;
let result =
bytemuck::try_cast_slice::<T, U>(&input).map_err(|_| anyhow!("Failed direct cast"))?;
cpu_store_result(dst, &result);
cpu_store_result(dst, result);
Ok(())
}

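
The index_select change above swaps a manually indexed loop for iterator adaptors. A minimal sketch of the same gather pattern on a flat buffer follows; the function name and inputs are illustrative, not part of the crate.

// Gather rows of `right_len` contiguous elements from `src` at the given
// `indices`, mirroring the inner copy loop of `index_select` above.
fn gather_rows(src: &[f32], indices: &[i32], right_len: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; indices.len() * right_len];
    for (i, idx) in indices.iter().enumerate() {
        let src_idx = *idx as usize * right_len;
        let dst_idx = i * right_len;
        out[dst_idx..dst_idx + right_len].copy_from_slice(&src[src_idx..src_idx + right_len]);
    }
    out
}

fn main() {
    // Two "rows" of length 3; select row 1 then row 0.
    let src = [0.0, 1.0, 2.0, 10.0, 11.0, 12.0];
    assert_eq!(gather_rows(&src, &[1, 0], 3), vec![10.0, 11.0, 12.0, 0.0, 1.0, 2.0]);
}

Iterating with enumerate() keeps the row-copy semantics identical while avoiding the kind of indexing loop clippy flags (needless_range_loop).
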
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/gpu/wgsl/kernel_builder.rs
@@ -123,7 +123,7 @@ impl WgslKernelBuilder {
}

fn init_main(&mut self) {
self.main.write(&format!("{}\n", self.workgroup_size));
self.main.write(format!("{}\n", self.workgroup_size));
self.main.write("fn main(\n");
for (b, builtin) in self.builtins.iter().enumerate() {
let mut builtin = builtin.render();
5 changes: 2 additions & 3 deletions crates/ratchet-core/src/op.rs
@@ -3,10 +3,9 @@ use crate::gpu::{
PoolError, WgpuDevice,
};
use crate::{
ops::*, rvec, CPUBuffer, CompiledOp, InvariantError, Kernel, KernelBuildError, KernelMetadata,
KernelModuleDesc, RVec, Storage, StorageView, Tensor, WgslFragment, WorkgroupSize,
ops::*, rvec, CompiledOp, InvariantError, Kernel, KernelBuildError, KernelMetadata,
KernelModuleDesc, RVec, StorageView, Tensor, WgslFragment, WorkgroupSize,
};
use bytemuck::NoUninit;
use std::borrow::Cow;
use std::fmt::Debug;

9 changes: 4 additions & 5 deletions crates/ratchet-core/src/ops/binary.rs
@@ -5,12 +5,11 @@ use inline_wgsl::wgsl;
use ratchet_macros::WgslMetadata;

use crate::{
binary_apply_inplace, cpu_store_result,
gpu::{dtype::WgslDType, BindGroupLayoutDescriptor},
rvec, Array, BindingMode, BuiltIn, CPUOperation, DType, GPUOperation, InvariantError, Kernel,
KernelElement, KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec,
Scalar, Shape, StorageView, Strides, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive,
WorkgroupSize, Workload,
rvec, Array, BindingMode, BuiltIn, DType, GPUOperation, InvariantError, Kernel, KernelElement,
KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Scalar, Shape,
StorageView, Strides, Tensor, Vec2, Vec4, WgslKernelBuilder, WgslPrimitive, WorkgroupSize,
Workload,
};
#[cfg(test)]
use test_strategy::Arbitrary;