From 4dd7cd7cfd0997a2f16b36e05453a414e9ea1fb3 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Tue, 9 Jan 2024 22:22:36 -0600 Subject: [PATCH] Small kernel simplifications --- src/compilers/metal/fp16/matmul.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/compilers/metal/fp16/matmul.rs b/src/compilers/metal/fp16/matmul.rs index 16948cf1..34a8364a 100644 --- a/src/compilers/metal/fp16/matmul.rs +++ b/src/compilers/metal/fp16/matmul.rs @@ -42,20 +42,22 @@ kernel void kernel_vecmat( const device half4* in_vec [[buffer(0)]], const device half4* mat [[buffer(1)]], device half4* out_vec [[buffer(2)]], - const constant int& in_vec_size_divided_by_4 [[buffer(3)]], - const constant int& out_vec_size_divided_by_4 [[buffer(4)]], + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], threadgroup half4* tgp_memory [[threadgroup(0)]], uint3 tid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{ + int in_vec_size_divided_by_4 = in_vec_size / 4; + int out_vec_size_divided_by_4 = out_vec_size / 4; + uint out_col = tid.x * BN + lid.x; + // Thread local accumulation results half4 result = 0; // Threadgroup accumulation results threadgroup half4* tgp_results = tgp_memory + lid.x * BM; - uint out_col = tid.x * BN + lid.x; - // Per thread accumulation main loop for(int bm = lid.y; bm < in_vec_size_divided_by_4; bm += BM) {{ #pragma unroll(TM) @@ -70,7 +72,7 @@ kernel void kernel_vecmat( threadgroup_barrier(mem_flags::mem_threadgroup); // Threadgroup accumulation and writing out results - if(lid.y == 0 && out_col * TN < out_vec_size_divided_by_4 * 4) {{ + if(lid.y == 0 && out_col * TN < out_vec_size) {{ #pragma unroll(BM) for(int i = 1; i < BM; i++) {{ result += tgp_results[i]; @@ -112,8 +114,8 @@ impl MetalKernel for MetalVecMat { encoder.set_buffer(0, Some(inputs[0].0), 0); encoder.set_buffer(1, Some(inputs[1].0), 0); encoder.set_buffer(2, Some(output_buffers[0]), 0); - encoder.set_u32(3, (m / 4) as u32); - encoder.set_u32(4, (n / 4) as u32); + encoder.set_u32(3, m as u32); + encoder.set_u32(4, n as u32); encoder.set_threadgroup_memory_length(0, BN * BM * 8); encoder.set_compute_pipeline_state(&self.kernel);