From b54da0ddde28ee55f0710bb32efba87ded491ef4 Mon Sep 17 00:00:00 2001 From: Joe Fioti Date: Fri, 12 Jan 2024 16:23:09 -0600 Subject: [PATCH] bring in line with ggml kernel --- src/compilers/metal/fp16/matmul.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/compilers/metal/fp16/matmul.rs b/src/compilers/metal/fp16/matmul.rs index 118a3cc8..94fa54c6 100644 --- a/src/compilers/metal/fp16/matmul.rs +++ b/src/compilers/metal/fp16/matmul.rs @@ -43,11 +43,10 @@ kernel void matvec( device const half4* mat = (device const half4*)(mat_bytes + threadgroup_pos.x * M * 2 + chunk_offset); device const half4* vec = (device const half4*)(vec_bytes + chunk_offset); - half4 sum4 = 0; + half sum = 0; for (int i = simd_pos; i < M/32; i += 32) { - sum4 += mat[i] * vec[i]; + for (int k = 0; k < 4; ++k) sum += mat[i][k] * vec[i][k]; } - half sum = sum4[0] + sum4[1] + sum4[2] + sum4[3]; half all_sum = simd_sum(sum); if (simd_pos == 0) { tgp_memory[thread_pos.z] = all_sum;