Skip to content

Commit

Permalink
Added support for transpose in matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 12, 2024
1 parent b3e07bd commit c506d1e
Showing 1 changed file with 9 additions and 73 deletions.
82 changes: 9 additions & 73 deletions src/compilers/metal/fp16/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,43 +486,11 @@ impl Compiler for MetalMatMulCompiler {
}
src2_shape.remove_dim(0);
src2_shape.permute(&[1, 0]);
// Pad out M to a multiple of 32, N to a multiple of 256 and K to a multiple of 16
let n_dim = Expression::from(src2_shape.shape()[1].clone());
let k_dim = Expression::from(src2_shape.shape()[0].clone());
let m_dim = if src1_shape.len() == 3 {
Expression::from(src1_shape.shape()[1].clone())
} else {
Expression::from(src1_shape.shape()[0].clone())
};
let mut padded = false;
let k_padding = if k_dim.to_usize().map(|i| i % 16 != 0).unwrap_or(true) {
(k_dim + 15) / 16 * 16 - k_dim
} else {
0.into()
};
let m_padding = if m_dim.to_usize().map(|i| i % 32 != 0).unwrap_or(true) {
padded = true;
(m_dim + 31) / 32 * 32 - m_dim
} else {
0.into()
};
let n_padding = if n_dim.to_usize().map(|i| i % 256 != 0).unwrap_or(true) {
padded = true;
(n_dim + 255) / 256 * 256 - n_dim
} else {
0.into()
};
if src1_shape.len() == 2 {
src1_shape.pad(&[(0.into(), m_padding), (0.into(), k_padding)]);
} else {
src1_shape.pad(&[
(0.into(), 0.into()),
(0.into(), m_padding),
(0.into(), k_padding),
]);
}
src2_shape.pad(&[(0.into(), k_padding), (0.into(), n_padding)]);
if !src1_shape.is_contiguous() || src1_shape.is_sliced() || src1_shape.is_padded() {
// If src1 is padded or sliced, or batch dim isn't first, we need to make it contiguous
if (src1_shape.len() == 3 && src1_shape.indexes[0] != 0)
|| src1_shape.is_sliced()
|| src1_shape.is_padded()
{
src1 = graph
.add_op(MetalContiguous::<f16>::new(
src1_shape,
Expand All @@ -535,7 +503,8 @@ impl Compiler for MetalMatMulCompiler {
.finish();
src1_shape = src1_shape.contiguous();
}
if !src2_shape.is_contiguous() || src2_shape.is_sliced() || src2_shape.is_padded() {
// If src2 is padded or sliced we need to make it contiguous
if src2_shape.is_sliced() || src2_shape.is_padded() {
src2 = graph
.add_op(MetalContiguous::<f16>::new(
src2_shape,
Expand All @@ -552,7 +521,7 @@ impl Compiler for MetalMatMulCompiler {
pipeline_state_descriptor.set_compute_function(Some(
&matmul_library
.get_function(
"gemm_nn_float16_float16_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_taligned",
&format!( "gemm_{}{}_float16_float16_bm32_bn32_bk16_wm2_wn2_MN_naligned_K_taligned", if src1_shape.is_contiguous() {"n"} else {"t"}, if src2_shape.is_contiguous() {"n"} else {"t"}),
None,
)
.unwrap(),
Expand All @@ -562,44 +531,11 @@ impl Compiler for MetalMatMulCompiler {
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let mut matmul_op = graph
let matmul_op = graph
.add_op(Matmul(pipeline, queue.clone(), dev.clone()))
.input(src1, 0, src1_shape)
.input(src2, 0, src2_shape)
.finish();
// Slice back to original size
if padded {
let new_shape = if src1_shape.len() == 3 {
let mut n = ShapeTracker::new(&[
Expression::from(src1_shape.shape()[0].clone()),
Expression::from(src1_shape.shape()[1].clone()),
Expression::from(src2_shape.shape()[1].clone()),
]);
n.slice(&[
(0.into(), i32::MAX.into()),
(0.into(), m_dim),
(0.into(), n_dim),
]);
n
} else {
let mut n = ShapeTracker::new(&[
Expression::from(src1_shape.shape()[0].clone()),
Expression::from(src2_shape.shape()[1].clone()),
]);
n.slice(&[(0.into(), i32::MAX.into()), (0.into(), n_dim)]);
n
};
matmul_op = graph
.add_op(MetalContiguous::<f16>::new(
new_shape,
dev.clone(),
queue.clone(),
&mut HashMap::new(),
&graph.dyn_map,
))
.input(matmul_op, 0, new_shape)
.finish();
}

// Create edges to dests
move_outgoing_edge(sum_reduce, matmul_op, &mut graph.graph);
Expand Down

0 comments on commit c506d1e

Please sign in to comment.