Test commit
Joe Fioti authored and Joe Fioti committed Jan 12, 2024
1 parent ec09c02 commit cf0e6ad
Showing 1 changed file with 195 additions and 26 deletions.
221 changes: 195 additions & 26 deletions src/compilers/metal/fp16/matmul.rs
@@ -11,6 +11,164 @@ use crate::{

use metal_rs::{objc::rc::autoreleasepool, *};

/// Multiplies an M-length vector with an MxN matrix, resulting in an N-length vector. Expects the matrix to be NxM row-major.
#[derive(LuminalEq, LuminalPrint, Clone)]
pub struct MatVec1Row {
    pipeline: ComputePipelineState,
    queue: CommandQueue,
    device: Device,
}

impl MatVec1Row {
    fn compile(device: &Device) -> ComputePipelineState {
        compile_function(
            "matvec",
            "
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
#include <metal_simdgroup>
using namespace metal;

kernel void matvec(
    device const char * src0,
    device const char * src1,
    device float * dst,
    constant int64_t & ne00,
    constant int64_t & ne01,
    constant int64_t & ne02,
    constant uint64_t & nb00,
    constant uint64_t & nb01,
    constant uint64_t & nb02,
    constant int64_t & ne10,
    constant int64_t & ne11,
    constant int64_t & ne12,
    constant uint64_t & nb10,
    constant uint64_t & nb11,
    constant uint64_t & nb12,
    constant int64_t & ne0,
    constant int64_t & ne1,
    uint3 tgpig[[threadgroup_position_in_grid]],
    uint tiisg[[thread_index_in_simdgroup]]
) {
    // Broadcast ratios, hard-coded to 1 (no broadcasting); 0 would divide by zero below.
    const uint r2 = 1;
    const uint r3 = 1;
    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;
    const uint i12 = im % ne12;
    const uint i13 = im / ne12;
    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
    device const half * x = (device const half *) (src0 + offset0);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    // Each of the 32 SIMD lanes accumulates a strided partial dot product over the row.
    float sumf = 0;
    device const half4 * x4 = (device const half4 *) x;
    device const float4 * y4 = (device const float4 *) y;
    for (int i = tiisg; i < ne00/4; i += 32) {
        for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
    }

    // Reduce across the simdgroup; lane 0 handles the tail elements and writes the result.
    float all_sum = simd_sum(sumf);
    if (tiisg == 0) {
        for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
        dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
    }
}
",
            device,
        )
    }
}
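
// Illustrative only: a plain-Rust reference for what the kernel above computes — each
// output element r0 is the dot product of matrix row r0 (N x M row-major) with the input
// vector. This helper is hypothetical (not part of this commit or the compiled graph) and
// is included just to make the reduction easy to check against.
#[allow(dead_code)]
fn matvec_reference(matrix: &[f16], vector: &[f32], m: usize, n: usize) -> Vec<f32> {
    (0..n)
        .map(|r0| {
            matrix[r0 * m..(r0 + 1) * m]
                .iter()
                .zip(vector)
                .map(|(x, y)| x.to_f32() * y)
                .sum()
        })
        .collect()
}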

impl MetalKernel for MatVec1Row {
    fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
        vec![input_shapes[1].shape()[1].clone() * size_of::<f16>()]
    }

    fn metal_forward(
        &self,
        inputs: &[(&Buffer, ShapeTracker)],
        command_buffer: &CommandBufferRef,
        _: &[&Buffer],
        output_buffers: &[&Buffer],
    ) {
        let (m, n) = (
            inputs[0].1.shape()[0].to_usize().unwrap(),
            inputs[1].1.shape()[1].to_usize().unwrap(),
        );

        let encoder =
            command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());

        // Set inputs
        encoder.set_buffer(0, Some(inputs[1].0), 0);
        encoder.set_buffer(1, Some(inputs[0].0), 0);
        encoder.set_buffer(2, Some(output_buffers[0]), 0);
        encoder.set_i32(3, m as i32);
        encoder.set_i32(4, n as i32);
        encoder.set_i32(5, 0);
        encoder.set_i32(6, 0);
        encoder.set_threadgroup_memory_length(
            0,
            if inputs[1].1.is_contiguous() {
                BN * BM * 4
            } else {
                BN * 8
            },
        );

        encoder.set_compute_pipeline_state(&self.pipeline);
        let b = if inputs[1].1.is_contiguous() { BN } else { BM };
        encoder.dispatch_thread_groups(
            MTLSize::new((n as u64).div_ceil(b * 4), 1, 1),
            MTLSize::new(BN, BM, 1),
        );
        encoder.end_encoding();
    }
}
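
// A small sketch of the dispatch arithmetic above, using assumed stand-ins for the BN/BM
// constants (they are defined elsewhere in this file; 32 and 4 here are illustrative, not
// the actual values). Each threadgroup covers `b * 4` output elements, so the grid width
// is n rounded up to that granularity.
#[allow(dead_code)]
fn example_threadgroup_count(n: u64, rhs_contiguous: bool) -> u64 {
    let (bn, bm) = (32u64, 4u64); // assumed stand-ins for BN / BM
    let b = if rhs_contiguous { bn } else { bm };
    n.div_ceil(b * 4)
}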

impl Operator for MatVec1Row {
    fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
        autoreleasepool(|| {
            // Setup command queue / command buffer / encoder
            let command_buffer = self.queue.new_command_buffer();

            let n = inp[1].1.shape()[1].to_usize().unwrap();

            let out = self.device.new_buffer(
                (n * std::mem::size_of::<f16>()) as u64,
                MTLResourceOptions::StorageModeShared,
            );
            self.metal_forward(
                &[
                    (get_buffer_from_tensor(&inp[0].0), inp[0].1),
                    (get_buffer_from_tensor(&inp[1].0), inp[1].1),
                ],
                command_buffer,
                &[],
                &[&out],
            );

            command_buffer.commit();
            command_buffer.wait_until_completed();

            vec![Tensor::new(out)]
        })
    }

    fn custom(&mut self, key: &str, _: Box<dyn Any>) -> Option<Box<dyn Any>> {
        if key == "metal" {
            return Some(Box::new(MetalKernelWrapper(Arc::new(Box::new(
                self.clone(),
            )))));
        }
        None
    }
}
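
// Rough usage sketch (names and signatures are illustrative, not taken from this commit):
// the compiler below swaps a matched matmul pattern for one of these ops, after which a
// graph built roughly along these lines would dispatch the Metal kernel when executed:
//
//     let mut cx = Graph::new();
//     let a = cx.tensor::<R1<M>>();      // M-length vector
//     let b = cx.tensor::<R2<M, N>>();   // matrix
//     let mut c = a.matmul(b).retrieve();
//     cx.compile(MetalMatMulCompiler::default(), &mut c);
//     cx.execute();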

/// Multiplies an M-length vector with an MxN matrix, resulting in an N-length vector. Expects the matrix to be NxM row-major.
#[derive(LuminalEq, LuminalPrint, Clone)]
pub struct MatVec {
@@ -333,33 +491,44 @@ impl Compiler for MetalMatMulCompiler {
                    .finish();
                src2_shape = src2_shape.contiguous();
            }

            let matmul_op = if !src2_shape.is_contiguous() && false {
                graph
                    .add_op(MatVec1Row {
                        pipeline: MatVec1Row::compile(&dev),
                        device: dev.clone(),
                        queue: queue.clone(),
                    })
                    .input(src1, 0, src1_shape)
                    .input(src2, 0, src2_shape)
                    .finish()
            } else {
                let pipeline_state_descriptor = ComputePipelineDescriptor::new();
                pipeline_state_descriptor.set_compute_function(Some(
                    &matvec_library
                        .get_function(
                            &format!(
                                "gemv_{}float16_bm{BM}_bn{BN}_tm4_tn4",
                                if src2_shape.is_contiguous() { "t_" } else { "" }
                            ),
                            None,
                        )
                        .unwrap(),
                ));
                let pipeline = dev
                    .new_compute_pipeline_state_with_function(
                        pipeline_state_descriptor.compute_function().unwrap(),
                    )
                    .unwrap();
                graph
                    .add_op(MatVec {
                        pipeline,
                        device: dev.clone(),
                        queue: queue.clone(),
                    })
                    .input(src1, 0, src1_shape)
                    .input(src2, 0, src2_shape)
                    .finish()
            };
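
            // For reference: the else branch above fetches a pre-compiled gemv variant by name,
            // e.g. "gemv_t_float16_bm8_bn32_tm4_tn4" for a contiguous rhs. The bm/bn numbers come
            // from the BM/BN constants defined elsewhere in this file; 8 and 32 here are
            // illustrative, not the actual values.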

            // Create edges to dests
            move_outgoing_edge(sum_reduce, matmul_op, &mut graph.graph);
