Skip to content

Commit

Permalink
Small kernel simplifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 10, 2024
1 parent 3670378 commit 4dd7cd7
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,22 @@ kernel void kernel_vecmat(
const device half4* in_vec [[buffer(0)]],
const device half4* mat [[buffer(1)]],
device half4* out_vec [[buffer(2)]],
const constant int& in_vec_size_divided_by_4 [[buffer(3)]],
const constant int& out_vec_size_divided_by_4 [[buffer(4)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup half4* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {{
int in_vec_size_divided_by_4 = in_vec_size / 4;
int out_vec_size_divided_by_4 = out_vec_size / 4;
uint out_col = tid.x * BN + lid.x;
// Thread local accumulation results
half4 result = 0;
// Threadgroup accumulation results
threadgroup half4* tgp_results = tgp_memory + lid.x * BM;
uint out_col = tid.x * BN + lid.x;
// Per thread accumulation main loop
for(int bm = lid.y; bm < in_vec_size_divided_by_4; bm += BM) {{
#pragma unroll(TM)
Expand All @@ -70,7 +72,7 @@ kernel void kernel_vecmat(
threadgroup_barrier(mem_flags::mem_threadgroup);
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col * TN < out_vec_size_divided_by_4 * 4) {{
if(lid.y == 0 && out_col * TN < out_vec_size) {{
#pragma unroll(BM)
for(int i = 1; i < BM; i++) {{
result += tgp_results[i];
Expand Down Expand Up @@ -112,8 +114,8 @@ impl MetalKernel for MetalVecMat {
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_u32(3, (m / 4) as u32);
encoder.set_u32(4, (n / 4) as u32);
encoder.set_u32(3, m as u32);
encoder.set_u32(4, n as u32);
encoder.set_threadgroup_memory_length(0, BN * BM * 8);

encoder.set_compute_pipeline_state(&self.kernel);
Expand Down

0 comments on commit 4dd7cd7

Please sign in to comment.