Skip to content

Commit

Permalink
Kernel cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Feb 13, 2024
1 parent bd56364 commit 3bd99c9
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions src/compilers/metal/quantized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ typedef struct {
} block_q8_0;
kernel void mkernel(
device const void* src0 [[buffer(0)]], // Quantized 2D matrix
device const half* src1 [[buffer(1)]], // Float src vector
device half* dst [[buffer(2)]], // Float dest vector
constant int64_t & src_vec_size [[buffer(3)]], // Matrix n cols (src vector size) (Must be >= 32)
constant int64_t & dest_vec_size [[buffer(4)]], // Matrix n rows (dest vector size) (Must be >= 4)
constant int64_t & mat_batch_stride [[buffer(5)]], // Matrix batch stride
constant int64_t & vec_batch_stride [[buffer(6)]], // Vector batch stride
device block_q8_0* x [[buffer(0)]], // Quantized 2D matrix
device half* y [[buffer(1)]], // Float src vector
device half* dst [[buffer(2)]], // Float dest vector
constant int64_t & src_vec_size [[buffer(3)]], // Matrix n cols (src vector size) (Must be >= 32)
constant int64_t & dest_vec_size [[buffer(4)]], // Matrix n rows (dest vector size) (Must be >= 4)
constant int64_t & mat_batch_stride [[buffer(5)]], // Matrix batch stride
constant int64_t & vec_batch_stride [[buffer(6)]], // Vector batch stride
uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index_in_simdgroup[[thread_index_in_simdgroup]],
uint simdgroup_index_in_threadgroup [[simdgroup_index_in_threadgroup]] // 2 simdgroups in a threadgroup
Expand All @@ -57,9 +57,9 @@ kernel void mkernel(
// This is the first row the simdgroup will work on (each simdgroup handles a block of 4 rows)
const int first_row = (threadgroup_position_in_grid.x * num_simdgroups_per_threadgroup + simdgroup_index_in_threadgroup) * num_rows;
// Offset in number of quant blocks
device const block_q8_0* x = ((device const block_q8_0*)src0) + (first_row * num_quants_per_row); // Add batch offset here
device const half* y = src1 + (threadgroup_position_in_grid.z * vec_batch_stride); // Add batch offset here
// Offsets
x += first_row * num_quants_per_row + threadgroup_position_in_grid.z * (mat_batch_stride / 32);
y += threadgroup_position_in_grid.z * vec_batch_stride;
dst += (threadgroup_position_in_grid.z * dest_vec_size);
// thread-local cache of vector values to work on. This thread must only work on 8 at a time
Expand Down

0 comments on commit 3bd99c9

Please sign in to comment.