Skip to content

Commit

Permalink
Improvements to vecmat
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 10, 2024
1 parent 7a9f9e0 commit 40a62e7
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 165 deletions.
2 changes: 1 addition & 1 deletion examples/mistral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type DeviceCompiler = CPUCompiler;

fn main() {
let prompt = "[INST]Write me a python implementation of merge sort[/INST]\n";
let tokens_to_generate = 150;
let tokens_to_generate = 50;

let tokenizer = SentencePieceBpeTokenizer::from_file(
"./examples/mistral/setup/mistral-7b-hf/tokenizer.model",
Expand Down
8 changes: 4 additions & 4 deletions src/compilers/metal/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl<T> MetalKernel for MetalSub<T> {
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, inp_size as u32);
encoder.set_u32(3, inp_size as u32);
input_dyn_dims(&self.3, unsafe { self.5.as_ref().unwrap() }, &encoder, 4);

// Execute
Expand Down Expand Up @@ -291,7 +291,7 @@ impl<T> MetalKernel for MetalEqual<T> {
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, inp_size as u32);
encoder.set_u32(3, inp_size as u32);
input_dyn_dims(&self.3, unsafe { self.5.as_ref().unwrap() }, &encoder, 4);

// Execute
Expand Down Expand Up @@ -532,8 +532,8 @@ impl<T: MetalFloat> Operator for MetalGather<T> {
encoder.set_buffer(0, Some(&index_buffer), 0);
encoder.set_buffer(1, Some(b_inp), 0);
encoder.set_buffer(2, Some(&out), 0);
encoder.set_int(3, indexes.len() as u32);
encoder.set_int(4, self.embed_dim as u32);
encoder.set_u32(3, indexes.len() as u32);
encoder.set_u32(4, self.embed_dim as u32);

// Execute
encoder.dispatch_threads(
Expand Down
137 changes: 39 additions & 98 deletions src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ pub struct MetalVecMat {
}

const BM: u64 = 8;
const BN: u64 = 8;
const TM: u64 = 4;
const TN: u64 = 4;
const BN: u64 = 32;
impl MetalVecMat {
fn new(dev: &Device, queue: CommandQueue) -> Self {
Self {
Expand All @@ -35,108 +33,51 @@ impl MetalVecMat {
#include <metal_simdgroup>
using namespace metal;
static constant constexpr const int BM = {BM};
static constant constexpr const int BN = {BN};
static constant constexpr const int TM = {TM};
static constant constexpr const int TN = {TN};
#define BM {BM}
#define BN {BN}
#define TM 4
#define TN 4
kernel void kernel_vecmat(
const device half* in_vec [[buffer(0)]],
const device half* mat [[buffer(1)]],
device half* out_vec [[buffer(2)]],
const constant int& in_vec_size [[buffer(3)]],
const constant int& out_vec_size [[buffer(4)]],
threadgroup half* tgp_memory [[threadgroup(0)]],
const device half4* in_vec [[buffer(0)]],
const device half4* mat [[buffer(1)]],
device half4* out_vec [[buffer(2)]],
const constant int& in_vec_size_divided_by_4 [[buffer(3)]],
const constant int& out_vec_size_divided_by_4 [[buffer(4)]],
threadgroup half4* tgp_memory [[threadgroup(0)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {{
uint3 lid [[thread_position_in_threadgroup]]) {{
// Appease compiler
(void)simd_gid;
(void)simd_lid;
// Thread local accumulation results
half4 result = 0;
// Thread local accumulation results
half result[TN] = {{0}};
half inter[TN];
half v_coeff[TM];
// Threadgroup accumulation results
threadgroup half4* tgp_results = tgp_memory + lid.x * BM;
// Threadgroup accumulation results
threadgroup half* tgp_results = tgp_memory + lid.x * BM * TN;
uint out_col = (tid.x * BN + lid.x) * TN;
uint in_row = lid.y * TM;
// Edgecase handling
if (out_col < out_vec_size) {{
out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
uint out_col = tid.x * BN + lid.x;
// Per thread accumulation main loop
int bm = in_row;
for(; bm < in_vec_size; bm += BM * TM) {{
// Adding a threadgroup_barrier improves performance slightly,
// possibly because it helps exploit the cache better
threadgroup_barrier(mem_flags::mem_none);
if(bm + TM <= in_vec_size) {{
#pragma unroll(TM)
for(int tm = 0; tm < TM; tm++) {{
v_coeff[tm] = in_vec[bm + tm];
}}
for(int bm = lid.y; bm < in_vec_size_divided_by_4; bm += BM) {{
#pragma unroll(TM)
for(int tm = 0; tm < TM; tm++) {{
#pragma unroll(TN)
for(int tn = 0; tn < TN; tn++) {{
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}}
#pragma unroll(TN)
for(int tn = 0; tn < TN; tn++) {{
result[tn] += v_coeff[tm] * inter[tn];
}}
}}
}} else {{ // Edgecase handling
for(int tm = 0; bm + tm < in_vec_size; tm++) {{
v_coeff[tm] = in_vec[bm + tm];
#pragma unroll(TN)
for(int tn = 0; tn < TN; tn++) {{
inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn];
}}
#pragma unroll(TN)
for(int tn = 0; tn < TN; tn++) {{
result[tn] += v_coeff[tm] * inter[tn];
}}
result += mat[(bm_tn * 4 + tm) * out_vec_size_divided_by_4 + out_col] * in_vec[bm][tm];
}}
}}
}}
}}
// Threadgroup collection
#pragma unroll(TN)
for(int i = 0; i < TN; i++) {{
tgp_results[lid.y * TN + i] = result[i];
}}
// Threadgroup collection
tgp_results[lid.y] = result;
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col < out_vec_size) {{
#pragma unroll(BM)
for(int i = 1; i < BM; i++) {{
#pragma unroll(TN)
for(int j = 0; j < TN; j++) {{
result[j] += tgp_results[i * TN + j];
}}
}}
// Threadgroup accumulation and writing out results
if(lid.y == 0 && out_col * TN < out_vec_size_divided_by_4 * 4) {{
#pragma unroll(BM)
for(int i = 1; i < BM; i++) {{
result += tgp_results[i];
}}
#pragma unroll(TN)
for(int j = 0; j < TN; j++) {{
out_vec[out_col + j] = result[j];
out_vec[out_col] = result;
}}
}}
}}"
),
dev,
Expand Down Expand Up @@ -171,14 +112,14 @@ impl MetalKernel for MetalVecMat {
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, m as u32);
encoder.set_int(4, n as u32);
encoder.set_threadgroup_memory_length(0, BN * TN * BM * TM);
encoder.set_u32(3, (m / 4) as u32);
encoder.set_u32(4, (n / 4) as u32);
encoder.set_threadgroup_memory_length(0, BN * BM * 8);

encoder.set_compute_pipeline_state(&self.kernel);
encoder.dispatch_thread_groups(
MTLSize {
width: (n as u64).div_ceil(BN * TN),
width: (n as u64).div_ceil(BN * 4),
height: 1,
depth: 1,
},
Expand Down Expand Up @@ -346,9 +287,9 @@ impl MetalKernel for MetalMatmul2D {
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, m as u32);
encoder.set_int(4, n as u32);
encoder.set_int(5, k as u32);
encoder.set_u32(3, m as u32);
encoder.set_u32(4, n as u32);
encoder.set_u32(5, k as u32);

encoder.set_compute_pipeline_state(&self.simd_shader);
encoder.dispatch_thread_groups(
Expand Down Expand Up @@ -518,9 +459,9 @@ impl MetalKernel for MetalBatchMatmul2D {
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(inputs[1].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_int(3, m as u32);
encoder.set_int(4, n as u32);
encoder.set_int(5, k as u32);
encoder.set_u32(3, m as u32);
encoder.set_u32(4, n as u32);
encoder.set_u32(5, k as u32);

// Execute
encoder.dispatch_thread_groups(
Expand Down
8 changes: 4 additions & 4 deletions src/compilers/metal/fp16/mean_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ impl MetalKernel for MetalMeanReduce {
// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_int(2, inp_size as u32);
encoder.set_int(3, front_size as u32);
encoder.set_int(4, back_size as u32);
encoder.set_int(5, dim_size as u32);
encoder.set_u32(2, inp_size as u32);
encoder.set_u32(3, front_size as u32);
encoder.set_u32(4, back_size as u32);
encoder.set_u32(5, dim_size as u32);
input_dyn_dims(&self.4, unsafe { self.5.as_ref().unwrap() }, encoder, 6);

// Execute
Expand Down
32 changes: 11 additions & 21 deletions src/compilers/metal/fp16/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ pub struct MetalRMSNorm {
impl MetalRMSNorm {
fn new(epsilon: f32, device: Device, queue: CommandQueue) -> Self {
let kernel_code = "#include <metal_stdlib>
#define SIMD_WIDTH 32
using namespace metal;
kernel void kernel_rms_norm(
device const void * src0,
device half * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
constant float & eps,
device const void * src0 [[buffer(0)]],
device half * dst [[buffer(1)]],
constant int64_t & ne00 [[buffer(2)]],
constant uint64_t & nb01 [[buffer(3)]],
constant float & eps [[buffer(4)]],
threadgroup float * buf [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
uint tpitg[[thread_position_in_threadgroup]],
Expand All @@ -48,7 +50,7 @@ kernel void kernel_rms_norm(
}
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
all_sum = simd_sum(all_sum);
if (ntg > 32) {
if (ntg > SIMD_WIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
Expand Down Expand Up @@ -104,21 +106,9 @@ impl MetalKernel for MetalRMSNorm {
// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
encoder.set_buffer(1, Some(output_buffers[0]), 0);
encoder.set_bytes(
2,
size_of::<i64>() as u64,
&(ne00 as i64) as *const i64 as *const _,
);
encoder.set_bytes(
3,
size_of::<u64>() as u64,
&(nb01 as u64) as *const u64 as *const _,
);
encoder.set_bytes(
4,
size_of::<f32>() as u64,
&self.epsilon as *const f32 as *const _,
);
encoder.set_i64(2, ne00 as i64);
encoder.set_u64(3, nb01 as u64);
encoder.set_f32(4, self.epsilon);

let mut nth = 32; // SIMD width
while nth < ne00 / 4 && nth < 1024 {
Expand Down
24 changes: 12 additions & 12 deletions src/compilers/metal/fp32/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,9 @@ impl Operator for MetalMatmul2D {
encoder.set_buffer(0, Some(a), 0);
encoder.set_buffer(1, Some(b), 0);
encoder.set_buffer(2, Some(&out), 0);
encoder.set_int(3, m as u32);
encoder.set_int(4, n as u32);
encoder.set_int(5, k as u32);
encoder.set_u32(3, m as u32);
encoder.set_u32(4, n as u32);
encoder.set_u32(5, k as u32);

if k >= 16 && n >= 256 && ((n != 0) && (n & (n - 1)) == 0) {
encoder.set_compute_pipeline_state(&self.simd_shader);
Expand Down Expand Up @@ -279,15 +279,15 @@ impl Operator for MetalBatchMatmul2D {
encoder.set_buffer(0, Some(a), 0);
encoder.set_buffer(1, Some(b), 0);
encoder.set_buffer(2, Some(&out), 0);
encoder.set_int(3, batch_size as u32);
encoder.set_int(4, m as u32);
encoder.set_int(5, k as u32);
encoder.set_int(6, n as u32);
encoder.set_int(7, a_row_major as u32);
encoder.set_int(8, b_row_major as u32);
encoder.set_int(9, a_strides[0].to_usize().unwrap() as u32);
encoder.set_int(10, 0);
encoder.set_int(11, (m * n) as u32);
encoder.set_u32(3, batch_size as u32);
encoder.set_u32(4, m as u32);
encoder.set_u32(5, k as u32);
encoder.set_u32(6, n as u32);
encoder.set_u32(7, a_row_major as u32);
encoder.set_u32(8, b_row_major as u32);
encoder.set_u32(9, a_strides[0].to_usize().unwrap() as u32);
encoder.set_u32(10, 0);
encoder.set_u32(11, (m * n) as u32);

// Execute
encoder.dispatch_1d(batch_size * n * m);
Expand Down
Loading

0 comments on commit 40a62e7

Please sign in to comment.