Skip to content

Commit

Permalink
Changed matvec
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed Jan 12, 2024
1 parent e5dcff3 commit 9295ff8
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 60 deletions.
6 changes: 3 additions & 3 deletions examples/mistral/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ fn main() {
## The Laws
The Three Laws, presented to be from the fictional \"Handbook of Robotics, 56th Edition, 2058 A.D.\", are:
- The First Law: A robot may not injure a human being or, through inaction, allow a human being to come to harm.
- The Second Law: A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
- The Third Law: A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
- The First Law: A robot may not injure a human being or, through inaction, allow a human being to come to harm.
- The Second Law: A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
- The Third Law: A robot must protect its own existence as long as such protection does not conflict with the First or Second Law.
";
let tokens_to_generate = 128;

Expand Down
107 changes: 50 additions & 57 deletions src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,52 +29,59 @@ impl MatVec1Row {
#include <metal_simdgroup>
using namespace metal;
void matvec(
device const char * src0,
device const char * src1,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]
// Matrix-vector product kernel: each threadgroup produces one output element,
// dst[threadgroup_pos.x], as the dot product of one matrix row with `vec`.
// The row is split into byte chunks selected by thread_pos.z; the threads that
// share a thread_pos.z each accumulate part of their chunk, the partial sums are
// combined with simd_sum, and one thread adds up the per-chunk results.
//
// Parameters:
//   mat_bytes  - matrix data addressed as raw bytes (buffer 0)
//   vec_bytes  - vector data addressed as raw bytes (buffer 1)
//   dst        - output, one half per threadgroup_pos.x (buffer 2)
//   M          - row length in elements (buffer 3); presumably half elements,
//                given the `* 2` byte stride below - TODO confirm
//   tgp_memory - threadgroup scratch holding one partial sum per thread_pos.z
//
// NOTE(review): this text comes from a flattened diff. Several statements in
// the body use names that are not declared by this signature (tgpig, ne12,
// nb01, nb02, ne02, src0, src1, nb11, nb12) - they look like lines of the
// previous version of the kernel rather than part of this one; confirm
// against the repository before relying on them.
kernel void matvec(
device const char* mat_bytes [[buffer(0)]],
device const char* vec_bytes [[buffer(1)]],
device half* dst [[buffer(2)]],
constant int& M [[buffer(3)]],
uint3 threadgroup_pos[[threadgroup_position_in_grid]],
uint3 thread_pos[[thread_position_in_threadgroup]],
uint simd_pos[[thread_index_in_simdgroup]],
threadgroup half* tgp_memory [[threadgroup(0)]]
) {
const uint r2 = 0;
const uint r3 = 0;
// Byte offset of this thread's chunk: M / 4 bytes per thread_pos.z value
// (the pointers it is added to are char pointers, so the unit is bytes).
int chunk_offset = thread_pos.z * (M / 4);
// Start of this threadgroup's row (threadgroup_pos.x * M * 2 bytes in) plus the chunk offset.
device const half4* mat = (device const half4*)(mat_bytes + threadgroup_pos.x * M * 2 + chunk_offset);
// The vector is shared by every row, so only the chunk offset applies.
device const half4* vec = (device const half4*)(vec_bytes + chunk_offset);
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int64_t im = tgpig.z;
const uint i12 = im%ne12;
const uint i13 = im/ne12;
const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
// Accumulate element-wise products four halves at a time. Each thread starts
// at its simd_pos and steps by 32; the bound M/32 half4 values equals the
// M/4-byte chunk if a half is 2 bytes.
// NOTE(review): integer division means any remainder of M is not visited -
// presumably M is a multiple of 32; confirm with the caller.
half4 sum4 = 0;
for (int i = simd_pos; i < M/32; i += 32) {
sum4 += mat[i] * vec[i];
}
// Collapse the four lanes of the vector accumulator into one scalar.
half sum = sum4[0] + sum4[1] + sum4[2] + sum4[3];
// Sum the per-thread scalars across the SIMD-group.
half all_sum = simd_sum(sum);
// One thread per SIMD-group publishes the chunk total, indexed by thread_pos.z.
if (simd_pos == 0) {
tgp_memory[thread_pos.z] = all_sum;
}
// Wait for every thread in the threadgroup before the totals are read back.
// NOTE(review): mem_flags::mem_none requests an execution barrier only; the
// reads of tgp_memory below probably need mem_flags::mem_threadgroup to be
// guaranteed to observe the writes above - confirm against the Metal spec.
threadgroup_barrier(mem_flags::mem_none);
device const half * x = (device const half *) (src0 + offset0);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
// A single thread adds the chunk totals and writes the row's result.
// The loop reads exactly 8 entries, so it assumes 8 distinct thread_pos.z
// values wrote to tgp_memory - TODO confirm against the dispatch size.
if (simd_pos == 0 && thread_pos.z == 0) {
half final_sum = 0;
#pragma unroll(8)
for (int i = 0; i < 8; ++i) {
final_sum += tgp_memory[i];
}
dst[threadgroup_pos.x] = final_sum;
}
}
float sumf = 0;
device const half4 * x4 = (device const half4 *) x;
device const float4 * y4 = (device const float4 *) y;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
// Simpler version of this kernel is ~5ms slower
//
// Single-pass variant: every thread walks the whole row of `mat` selected by
// threadgroup_pos.x (starting at its simd_pos, stepping by 32), the partial
// sums are combined with simd_sum, and the thread with simd_pos == 0 writes
// dst[threadgroup_pos.x]. No threadgroup memory is used.
//
// Parameters:
//   mat - matrix data viewed as half4 values (buffer 0)
//   vec - vector data viewed as half4 values (buffer 1)
//   dst - output, one half per threadgroup_pos.x (buffer 2)
//   M   - row length in half elements (buffer 3); the row is M / 4 half4 values
//
// NOTE(review): this text comes from a flattened diff. Some statements in the
// body use names this signature does not declare (tiisg, ne00, ne0, ne1, x, y,
// im, r0, r1) and `all_sum` is declared twice - those lines look like the
// previous version of the kernel rather than part of this one; confirm against
// the repository before relying on them.
kernel void matvec_simple(
device const half4* mat [[buffer(0)]],
device const half4* vec [[buffer(1)]],
device half* dst [[buffer(2)]],
constant int& M [[buffer(3)]],
uint3 threadgroup_pos[[threadgroup_position_in_grid]],
uint simd_pos[[thread_index_in_simdgroup]]
) {
// Advance to this threadgroup's row: (threadgroup_pos.x * M) halves, expressed in half4 units.
mat += (threadgroup_pos.x * M) / 4;
// Accumulate element-wise products four halves at a time.
// NOTE(review): the bound M/4 truncates, so trailing elements are not visited
// by this loop when M is not a multiple of 4 - presumably M always is; confirm.
half4 sumf = 0;
for (int i = simd_pos; i < M/4; i += 32) {
sumf += mat[i] * vec[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
// Collapse the four lanes into one scalar, then sum it across the SIMD-group.
half sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
half all_sum = simd_sum(sum);
// Only one thread per SIMD-group writes the result for this row.
if (simd_pos == 0) {
dst[threadgroup_pos.x] = all_sum;
}
}
",
Expand Down Expand Up @@ -108,24 +115,10 @@ impl MetalKernel for MatVec1Row {
encoder.set_buffer(1, Some(inputs[0].0), 0);
encoder.set_buffer(2, Some(output_buffers[0]), 0);
encoder.set_i32(3, m as i32);
encoder.set_i32(4, n as i32);
encoder.set_i32(5, 0 as i32);
encoder.set_i32(6, 0 as i32);
encoder.set_threadgroup_memory_length(
0,
if inputs[1].1.is_contiguous() {
BN * BM * 4
} else {
BN * 8
},
);
encoder.set_threadgroup_memory_length(0, (8 * std::mem::size_of::<f16>()) as u64);

encoder.set_compute_pipeline_state(&self.pipeline);
let b = if inputs[1].1.is_contiguous() { BN } else { BM };
encoder.dispatch_thread_groups(
MTLSize::new((n as u64 + b * 4 - 1).div_ceil(b * 4), 1, 1),
MTLSize::new(BN, BM, 1),
);
encoder.dispatch_thread_groups(MTLSize::new(n as u64, 1, 1), MTLSize::new(1, 32, 8));
encoder.end_encoding();
}
}
Expand Down

0 comments on commit 9295ff8

Please sign in to comment.