Generalized matmul compiler
jafioti committed Jan 20, 2024
1 parent e89bdbb commit 8bd7598
Showing 4 changed files with 121 additions and 79 deletions.
113 changes: 79 additions & 34 deletions src/compilers/metal/fp16/matmul.rs
@@ -50,16 +50,15 @@ const BM: u64 = 8;
const BN: u64 = 32;
impl MetalKernel for Matmul {
fn output_buffer_sizes(&self, input_shapes: &[ShapeTracker]) -> Vec<BigExpression> {
let n = input_shapes[1].shape()[1].clone();
let (batch_size, m) = if input_shapes[0].len() == 3 {
(
input_shapes[0].shape()[0].clone(),
input_shapes[0].shape()[1].clone(),
)
} else {
(1.into(), input_shapes[0].shape()[0].clone())
};
vec![BigExpression::from(m) * n * batch_size * size_of::<f16>()]
let m = input_shapes[0].shape()[input_shapes[0].len() - 2].clone();
let n = input_shapes[1].shape()[input_shapes[1].len() - 1].clone();
let batch_size = input_shapes[0]
.shape()
.into_iter()
.take(input_shapes[0].len() - 2)
.product::<BigExpression>()
.max(BigExpression::from(1));
vec![batch_size * m * n * size_of::<f16>()]
}
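
The output size no longer assumes a 2D or 3D input: the last two dims are treated as the matrix dims and every leading dim is folded into one batch dimension. A minimal sketch of the same bookkeeping on plain usize shapes (the helper name is made up for illustration; the real code works on ShapeTracker/BigExpression):

// Hypothetical helper (not part of the crate) mirroring output_buffer_sizes on plain usizes.
fn matmul_out_elems(a_shape: &[usize], b_shape: &[usize]) -> usize {
    let m = a_shape[a_shape.len() - 2];      // rows of A
    let n = b_shape[b_shape.len() - 1];      // cols of B
    let batch = a_shape[..a_shape.len() - 2] // every leading dim folds into the batch
        .iter()
        .product::<usize>()
        .max(1);                             // plain 2D matmul still gets batch = 1
    batch * m * n
}

fn main() {
    // 4D activations against a shared 2D weight: batch = 1 * 32
    assert_eq!(matmul_out_elems(&[1, 32, 11, 128], &[128, 64]), 32 * 11 * 64);
    // Plain 2D matmul is unchanged
    assert_eq!(matmul_out_elems(&[11, 128], &[128, 64]), 11 * 64);
}
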
fn metal_forward(
&self,
@@ -69,18 +68,17 @@ impl MetalKernel for Matmul {
output_buffers: &[&Buffer],
) {
let (a_shape, b_shape) = (inputs[0].1.shape(), inputs[1].1.shape());
let (k, n) = (
b_shape[0].to_usize().unwrap(),
b_shape[1].to_usize().unwrap(),
);
let (batch_size, m) = if a_shape.len() == 3 {
(
a_shape[0].to_usize().unwrap(),
a_shape[1].to_usize().unwrap(),
)
} else {
(1, a_shape[0].to_usize().unwrap())
};
let a_dims = a_shape.len();
let m = a_shape[a_dims - 2].to_usize().unwrap();
let batch_size = a_shape
.iter()
.take(a_dims - 2)
.map(|i| i.to_usize().unwrap())
.product::<usize>()
.max(1);
let b_dims = b_shape.len();
let k = b_shape[b_dims - 2].to_usize().unwrap();
let n = b_shape[b_dims - 1].to_usize().unwrap();

let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
@@ -119,7 +117,14 @@ impl MetalKernel for Matmul {
encoder.set_i32(4, n as i32);
encoder.set_i32(5, k as i32);
encoder.set_i32(6, (m * k) as i32); // A batch stride
encoder.set_i32(7, 0); // B batch stride
encoder.set_i32(
7,
if inputs[1].1.len() == 2 {
0
} else {
(k * n) as i32
},
); // B batch stride
encoder.set_i32(8, (m * n) as i32); // C batch stride

// Execute
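
The batch strides passed at buffer indices 6–8 encode how far each operand advances per batch step: A moves by one m×k matrix, the output by one m×n matrix, and B moves by k×n only when it has batch dimensions of its own; a 2D B keeps a zero stride so every batch reuses the same weight matrix. A rough sketch of that rule, assuming strides are counted in elements, for a hypothetical [batch, 11, 128] x [128, 64] matmul with a shared 2D weight:

// Sketch only: the values the encoder.set_i32 calls above would carry.
fn main() {
    let (m, k, n) = (11i32, 128, 64);
    let b_is_2d = true;                                    // B has no batch dims of its own
    let a_batch_stride = m * k;                            // one A matrix per batch step
    let b_batch_stride = if b_is_2d { 0 } else { k * n };  // 0 => broadcast the same B
    let c_batch_stride = m * n;                            // one output matrix per batch step
    println!("{a_batch_stride} {b_batch_stride} {c_batch_stride}");
}
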
@@ -228,7 +233,7 @@ impl Compiler for MetalMatMulCompiler {
])
.fakes(vec![
vec![Some(false), Some(false), Some(true), Some(false)],
vec![Some(true), Some(true), Some(false), Some(false)],
vec![None, Some(true), Some(false), Some(false)],
])
.ptr(&mut mul)
.edge(
@@ -244,8 +249,41 @@ impl Compiler for MetalMatMulCompiler {
.ptr(&mut sum_reduce),
)
.search(graph);
let mut batch_batch_searcher = SelectOp::new()
.ty::<MetalMul<f16>>()
.shapes(vec![
vec!['E'.into(), 'D'.into(), 'A'.into(), 'C'.into(), 'B'.into()],
vec!['E'.into(), 'D'.into(), 'A'.into(), 'C'.into(), 'B'.into()],
])
.fakes(vec![
vec![
Some(false),
Some(false),
Some(false),
Some(true),
Some(false),
],
vec![None, None, Some(true), Some(false), Some(false)],
])
.ptr(&mut mul)
.edge(
SelectOp::new()
.ty::<MetalSumReduce<f16>>()
.check(|o, _| {
if let Some(o) = o.as_any().downcast_ref::<MetalSumReduce<f16>>() {
o.dim == 4
} else {
false
}
})
.ptr(&mut sum_reduce),
)
.search(graph);
let (matmul_library, matvec_library) = compile_libs(&dev);
while single_searcher.next_match() || batch_searcher.next_match() {
while single_searcher.next_match()
|| batch_searcher.next_match()
|| batch_batch_searcher.next_match()
{
if graph.no_delete.contains(&mul) {
// The intermediate mul can't be deleted
continue;
@@ -255,14 +293,18 @@ impl Compiler for MetalMatMulCompiler {
let (mut src1, mut src1_shape) = (srcs[0].0, srcs[0].2);
let (mut src2, mut src2_shape) = (srcs[1].0, srcs[1].2);
// Undo expansions and permute
src1_shape.remove_dim(if src1_shape.len() == 4 { 2 } else { 1 });
if src2_shape.len() == 4 {
src2_shape.remove_dim(1);
}
src2_shape.remove_dim(0);
src2_shape.permute(&[1, 0]);
src1_shape.remove_dim(src1_shape.len() - 2);
src2_shape.remove_dim(src2_shape.len() - 3);
let mut dims = (0..src2_shape.len()).collect_vec();
dims.swap(src2_shape.len() - 2, src2_shape.len() - 1);
src2_shape.permute(&dims);
// If src1 is padded or sliced, or batch dim isn't first, we need to make it contiguous
if (src1_shape.len() == 3 && src1_shape.indexes[0] != 0)
if (src1_shape
.indexes
.iter()
.take(src1_shape.len() - 2)
.enumerate()
.any(|(a, b)| a != *b))
|| src1_shape.is_sliced()
|| src1_shape.is_padded()
{
@@ -277,8 +319,11 @@ impl Compiler for MetalMatMulCompiler {
.finish();
src1_shape = src1_shape.contiguous();
}
// If src1 is padded or sliced we need to make it contiguous
if src2_shape.is_sliced() || src2_shape.is_padded() {
// If src2 is padded or sliced, or batch dim isn't first, we need to make it contiguous
if (src2_shape.len() == 3 && src2_shape.indexes[0] != 0)
|| src2_shape.is_sliced()
|| src2_shape.is_padded()
{
src2 = graph
.add_op(MetalContiguous::<f16>::new(
src2_shape,
5 changes: 3 additions & 2 deletions src/compilers/metal/fp16/rms_norm.rs
@@ -100,8 +100,9 @@ impl MetalKernel for MetalRMSNorm {
let encoder =
command_buffer.compute_command_encoder_with_descriptor(ComputePassDescriptor::new());
encoder.set_compute_pipeline_state(&self.pipeline);
let ne00 = inputs[0].1.shape()[2].to_usize().unwrap();
let nb01 = inputs[0].1.strides()[1].to_usize().unwrap() * size_of::<f16>();
let n_dims = inputs[0].1.len();
let ne00 = inputs[0].1.shape()[n_dims - 1].to_usize().unwrap();
let nb01 = inputs[0].1.strides()[n_dims - 2].to_usize().unwrap() * size_of::<f16>();

// Set inputs
encoder.set_buffer(0, Some(inputs[0].0), 0);
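
The RMSNorm kernel setup gets the same generalization: instead of hard-coding a rank-3 input, the row length and row stride are read from the back of the shape, so any number of leading batch dims is tolerated. A hypothetical illustration with plain arrays (the real values come from a ShapeTracker, and u16 stands in for f16 only for its size):

// Sketch: the last dim is the normalized row length; the stride of the dim
// before it, in f16-sized bytes, is the distance between consecutive rows.
fn main() {
    let shape = [1usize, 32, 11, 128];
    let strides = [45056usize, 1408, 128, 1]; // contiguous strides for that shape
    let n_dims = shape.len();
    let ne00 = shape[n_dims - 1];                                 // 128
    let nb01 = strides[n_dims - 2] * std::mem::size_of::<u16>();  // 128 * 2 bytes
    assert_eq!((ne00, nb01), (128, 256));
}
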
78 changes: 37 additions & 41 deletions src/compilers/metal/fp16/tests.rs
@@ -454,40 +454,41 @@ fn test_matmul() {
}
}

// #[test]
// fn test_attn_matmul() {
// let mut cx = Graph::new();
// let mut rng = StdRng::seed_from_u64(0);
// let a_data = random_vec_rng(32 * 11 * 128, &mut rng);
// let b_data = random_vec_rng(32 * 11 * 128, &mut rng);
// let a = cx.new_tensor::<R4<1, 32, 11, 128>>("Input");
// a.set(a_data.clone());
// a.keep();
// let b = cx.new_tensor::<R4<1, 32, 128, 11>>("Input");
// b.set(b_data.clone());
// b.keep();
// let c = a.matmul(b);
// c.retrieve();

// cx.compile(MetalFp16Compiler::default());
// cx.execute();

// let d_dev = Cpu::default();
// let d_a = d_dev
// .tensor_from_vec(
// a_data,
// (DConst::<1>, DConst::<32>, DConst::<11>, DConst::<128>),
// )
// .to_dtype::<f16>();
// let d_b = d_dev
// .tensor_from_vec(
// b_data,
// (DConst::<1>, DConst::<32>, DConst::<128>, DConst::<11>),
// )
// .to_dtype::<f16>();
// let d_c = d_a.matmul(d_b);
// assert_exact(&c.data(), &d_c.to_dtype::<f32>().as_vec());
// }
#[test]
fn test_attn_matmul() {
let mut cx = Graph::new();
let mut rng = StdRng::seed_from_u64(0);
let a_data = random_vec_rng(32 * 11 * 128, &mut rng);
let b_data = random_vec_rng(32 * 11 * 128, &mut rng);
let a = cx
.named_tensor::<R4<1, 32, 11, 128>>("Input")
.set(a_data.clone())
.keep();
let b = cx
.named_tensor::<R4<1, 32, 128, 11>>("Input")
.set(b_data.clone())
.keep();
let mut c = a.matmul(b).retrieve();

cx.compile(MetalFp16Compiler::default(), &mut c);
cx.execute();

let d_dev = Cpu::default();
let d_a = d_dev
.tensor_from_vec(
a_data,
(DConst::<1>, DConst::<32>, DConst::<11>, DConst::<128>),
)
.to_dtype::<f16>();
let d_b = d_dev
.tensor_from_vec(
b_data,
(DConst::<1>, DConst::<32>, DConst::<128>, DConst::<11>),
)
.to_dtype::<f16>();
let d_c = d_a.matmul(d_b);
assert_exact(&c.data(), &d_c.to_dtype::<f32>().as_vec());
}

#[test]
fn test_batch_matmul() {
@@ -710,17 +711,13 @@ fn test_rms_norm() {
.to_dtype::<f16>();
let a = dev
.tensor_from_vec(inp_data, (DConst::<15>, DConst::<4>))
.to_dtype::<f16>()
.to_dtype::<f32>();
.to_dtype::<f16>();
let var_f32 = a.clone().square().mean::<_, DAxis<1>>();
let inv_std_f32 = (var_f32 + 1e-6).sqrt().recip();
let x_f32 = inv_std_f32.broadcast() * a;
let out = weight.broadcast() * x_f32.to_dtype::<f16>();

assert_exact(
&b.data().into_iter().map(f16::from_f32).collect::<Vec<_>>(),
&out.as_vec(),
);
assert_close(&b.data(), &out.to_dtype::<f32>().as_vec());
}

#[test]
Expand Down Expand Up @@ -821,7 +818,6 @@ fn test_transformer_encoder_block() {
let d_a = d_dev.tensor_from_vec(vec![-1., 2., 3., 3., 3., -1.], (DConst::<2>, DConst::<3>));
let d_b = d_model.forward(d_a);

// Annoyingly dfdx transformer encoder outputs 0s in fp16 mode, so we need to use the fp32 mode. Result ends up being close enough
assert_close(&b.data(), &d_b.as_vec());
}

4 changes: 2 additions & 2 deletions src/core/shape/tracker.rs
@@ -185,9 +185,9 @@ impl ShapeTracker {
.into_iter()
.map(|i| (i, BigExpression::from(self.dims[i])))
// Add pads
.map(|(ind, dim)| (ind, dim + self.padding[ind].0 + self.padding[ind].1))
.map(|(i, dim)| (i, dim + self.padding[i].0 + self.padding[i].1))
// Slice
.map(|(ind, dim)| dim.min(self.slices[ind].1) - self.slices[ind].0)
.map(|(i, dim)| dim.min(self.slices[i].1) - self.slices[i].0)
.product();
if r == 0.into() {
1.into()
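
The element-count rule in this hunk is: pad each dim (front plus back), clip it to its slice window, multiply everything together, and treat a zero-sized result as one element. A small plain-usize sketch of that rule (the actual code works on BigExpression, so slice bounds may be symbolic):

// Hypothetical stand-in for ShapeTracker::n_elements on concrete sizes.
fn n_elements(dims: &[usize], pads: &[(usize, usize)], slices: &[(usize, usize)]) -> usize {
    let r: usize = dims
        .iter()
        .enumerate()
        .map(|(i, &d)| {
            let padded = d + pads[i].0 + pads[i].1; // add front + back padding
            padded.min(slices[i].1) - slices[i].0   // then clip to the slice window
        })
        .product();
    r.max(1) // a zero-sized shape still reports one element
}

fn main() {
    // dim 4, padded by (1, 1) -> 6, sliced to [0, 5) -> 5 elements
    assert_eq!(n_elements(&[4], &[(1, 1)], &[(0, 5)]), 5);
}
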
