Skip to content

Commit

Permalink
Shared Metal storage buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Dec 28, 2023
1 parent 422fd32 commit ee17a48
Show file tree
Hide file tree
Showing 16 changed files with 428 additions and 244 deletions.
6 changes: 3 additions & 3 deletions src/compilers/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl Compiler for MatMul2DCompiler {
pub struct MatMul2D;

impl Operator for MatMul2D {
fn process(&self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
let a_data = inp[0]
Expand Down Expand Up @@ -199,7 +199,7 @@ pub struct BatchedMatMul2D;

// ABCxCD -> ABD
impl Operator for BatchedMatMul2D {
fn process(&self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let (a_shape, b_shape) = (inp[0].1.shape(), inp[1].1.shape());
let (a_strides, b_strides) = (inp[0].1.strides(), inp[1].1.strides());
let a_data = inp[0]
Expand Down Expand Up @@ -339,7 +339,7 @@ impl Compiler for UnaryFusionCompiler {
pub struct FusedUnary(Vec<fn(f32) -> f32>);

impl Operator for FusedUnary {
fn process(&self, mut inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, mut inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let mut t = inp.pop().unwrap().0.cloned();
for a in t
.data
Expand Down
6 changes: 3 additions & 3 deletions src/compilers/metal/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<T> MetalKernelForward for MetalSub<T> {
}

impl<T: MetalFloat> Operator for MetalSub<T> {
fn process(&self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.1.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
Expand Down Expand Up @@ -310,7 +310,7 @@ impl<T> MetalKernelForward for MetalEqual<T> {
}

impl<T: MetalFloat> Operator for MetalEqual<T> {
fn process(&self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
let command_buffer = self.1.new_command_buffer();
let inp_size = tensors[0].1.n_elements().to_usize().unwrap();
Expand Down Expand Up @@ -485,7 +485,7 @@ kernel void metal_gather(device float *inp [[buffer(0)]], device {} *weights [[b
}

impl<T: MetalFloat> Operator for MetalGather<T> {
fn process(&self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup buffers
let indexes = tensors[0]
Expand Down
55 changes: 46 additions & 9 deletions src/compilers/metal/command_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ impl Compiler for CommandBufferCompiler {
.unwrap()
.downcast::<MetalKernelWrapper>()
.unwrap();
*graph.graph.node_weight_mut(*node).unwrap() = Box::new(MetalKernelOperation {
*graph.graph.node_weight_mut(*node).unwrap() = Box::new(CommandBufferWrapper {
wrapper,
dev: dev.clone(),
buffer: buffer.clone(),
Expand Down Expand Up @@ -203,7 +203,7 @@ struct ExecuteMetalKernels {
}

impl Operator for ExecuteMetalKernels {
fn process(&self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, _: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
let buffer = unsafe { &mut *self.buffer.get() };
buffer.commit();
buffer.wait_until_completed();
Expand All @@ -212,22 +212,50 @@ impl Operator for ExecuteMetalKernels {
}
}

#[derive(LuminalEq)]
struct MetalKernelOperation {
#[derive(LuminalEq, Clone)]
struct CommandBufferWrapper {
wrapper: Box<MetalKernelWrapper>,
dev: Device,
buffer: Arc<UnsafeCell<CommandBuffer>>,
dyn_map: *const HashMap<char, usize>,
}

impl std::fmt::Debug for MetalKernelOperation {
impl std::fmt::Debug for CommandBufferWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalKernel({:?})", self.wrapper.0)
self.wrapper.0.fmt(f)
}
}

impl Operator for MetalKernelOperation {
fn process(&self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
impl MetalKernelForward for CommandBufferWrapper {
fn intermediate_buffer_sizes(
&self,
input_shapes: &[ShapeTracker],
) -> Vec<symbolic::BigExpression> {
self.wrapper.0.intermediate_buffer_sizes(input_shapes)
}
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<symbolic::BigExpression> {
self.wrapper.0.output_buffer_sizes(input_shapes)
}
fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
dev: &Device,
_: &metal_rs::CommandBufferRef,
intermediate_buffers: &[&Buffer],
output_buffers: &[&Buffer],
) {
self.wrapper.0.metal_forward(
inputs,
&dev,
unsafe { &*self.buffer.get() },
intermediate_buffers,
output_buffers,
);
}
}

impl Operator for CommandBufferWrapper {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
// For now let's allocate the required buffers here
let inp_shapes = inp.iter().map(|(_, s)| *s).collect::<Vec<_>>();
let intermediate_buffers = self
Expand Down Expand Up @@ -256,7 +284,7 @@ impl Operator for MetalKernelOperation {
})
.collect::<Vec<_>>();
let output_buffers_ref = output_buffers.iter().collect::<Vec<_>>();
self.wrapper.0.metal_forward(
self.metal_forward(
&inp.iter()
.map(|(t, sh)| {
(
Expand All @@ -275,6 +303,15 @@ impl Operator for MetalKernelOperation {
.map(|b| Tensor { data: Box::new(b) })
.collect()
}

fn custom(&self, key: &str) -> Option<Box<dyn std::any::Any>> {
if key == "metal" {
return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
self.clone(),
)))));
}
None
}
}

#[test]
Expand Down
16 changes: 6 additions & 10 deletions src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static constant constexpr const int TN = {TN};
kernel void kernel_vecmat(
const device half* in_vec [[buffer(0)]],
const device half* mat [[buffer(1)]],
device half* out_vec [[buffer(2)]],
device half* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup half* tgp_memory [[threadgroup(0)]],
Expand All @@ -52,7 +52,7 @@ kernel void kernel_vecmat(
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {{
// Appease compiler
// Appease compiler
(void)simd_gid;
(void)simd_lid;
Expand Down Expand Up @@ -123,7 +123,7 @@ kernel void kernel_vecmat(
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col < out_vec_size) {{
#pragma unroll(BM)
for(int i = 1; i < BM; i++) {{
#pragma unroll(TN)
Expand Down Expand Up @@ -151,7 +151,6 @@ impl MetalKernelForward for MetalVecMat {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
vec![input_shapes[1].shape()[1].clone() * size_of::<f16>()]
}

fn metal_forward(
&self,
inputs: &[(&Buffer, ShapeTracker)],
Expand Down Expand Up @@ -194,7 +193,7 @@ impl MetalKernelForward for MetalVecMat {
}

impl Operator for MetalVecMat {
fn process(&self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup command queue / command buffer / encoder
let command_buffer = self.queue.new_command_buffer();
Expand Down Expand Up @@ -373,7 +372,7 @@ impl MetalKernelForward for MetalMatmul2D {
}

impl Operator for MetalMatmul2D {
fn process(&self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup command queue / command buffer / encoder
let command_buffer = self.queue.new_command_buffer();
Expand Down Expand Up @@ -549,7 +548,7 @@ impl MetalKernelForward for MetalBatchMatmul2D {
}

impl Operator for MetalBatchMatmul2D {
fn process(&self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup command queue / command buffer / encoder
let command_buffer = self.1.new_command_buffer();
Expand Down Expand Up @@ -982,9 +981,6 @@ impl Compiler for MetalMatMulCompiler {

#[cfg(test)]
mod tests {

use crate::tests::assert_close_precision;

crate::test_imports!();
#[test]
fn test_matrix_vector() {
Expand Down
2 changes: 1 addition & 1 deletion src/compilers/metal/fp16/mean_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl MetalKernelForward for MetalMeanReduce {
}

impl Operator for MetalMeanReduce {
fn process(&self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
fn process(&mut self, tensors: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
autoreleasepool(|| {
// Setup buffers
let a = tensors[0]
Expand Down
1 change: 1 addition & 0 deletions src/compilers/metal/fp16/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub type MetalFp16Compiler = (
rms_norm::RMSNormCompiler,
super::other::CopyCompiler<f16>,
super::command_buffer::CommandBufferCompiler,
super::storage_buffer::StorageBufferCompiler,
);

#[cfg(test)]
Expand Down
Loading

0 comments on commit ee17a48

Please sign in to comment.