Commit 7bd8de2: steel matmul is ass
jafioti committed Jan 23, 2024 (1 parent: 80915d3)
Showing 10 changed files with 1,521 additions and 0 deletions.
src/compilers/metal/kernels/steel/gemm/gemm.h (312 additions, 0 deletions)
@@ -0,0 +1,312 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "loader.h"
#include "mma.h"
#include "transforms.h"
#include "../utils.h"

using namespace metal;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////

namespace mlx {
namespace steel {

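// Empty tag type used to pick a gemm_loop specialization for each
// combination of tile-edge alignment (M, N, K).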
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};

template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    typename AccumType = typename AccumHelper<T>::accum_type,
    typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
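  // Row padding for the threadgroup tiles: 16 bytes (16 / sizeof(T)
  // elements) per row, staggering rows across threadgroup-memory banks.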
  STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
  STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
  STEEL_CONST short tgp_mem_size_a =
      transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
  STEEL_CONST short tgp_mem_size_b =
      transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
  STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;

  STEEL_CONST short tgp_size = WM * WN * 32;

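  // Cooperative loaders that stage the A and B tiles into threadgroup
  // memory; tile shapes and leading dimensions follow the transpose flags.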
  using loader_a_t = BlockLoader<
      T,
      transpose_a ? BK : BM,
      transpose_a ? BM : BK,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      !transpose_a,
      tgp_size>;
  using loader_b_t = BlockLoader<
      T,
      transpose_b ? BN : BK,
      transpose_b ? BK : BN,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      transpose_b,
      tgp_size>;
  using mma_t = BlockMMA<
      T,
      U,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
      transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
      AccumType,
      Epilogue>;

  /* Inner loop over the K dimension, specialized by tile-edge alignment */
  template <bool M_aligned, bool N_aligned, bool K_aligned_>
  static METAL_FUNC void gemm_loop(
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      const int gemm_k_iterations,
      thread loader_a_t& loader_a,
      thread loader_b_t& loader_b,
      thread mma_t& mma_op,
      thread const short& tgp_bm,
      thread const short& tgp_bn,
      thread const short& lbk,
      LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
    // Appease the compiler
    (void)l;

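    // Per-thread masks recording which elements of a partial edge tile are
    // in bounds; consumed by the masked load_safe paths below.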
    thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
    thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];

    if (!M_aligned) {
      short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
      loader_a.set_mask(tile_dims_A, mask_A);
    }

    if (!N_aligned) {
      short2 tile_dims_B =
          transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
      loader_b.set_mask(tile_dims_B, mask_B);
    }

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      // Load elements into threadgroup
      if (M_aligned) {
        loader_a.load_unsafe();
      } else {
        loader_a.load_safe(mask_A);
      }

      if (N_aligned) {
        loader_b.load_unsafe();
      } else {
        loader_b.load_safe(mask_B);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

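    // K tail: the final partial BK-wide block requires masked loads.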
    if (!K_aligned_) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      short2 tile_dims_A_last =
          transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
      short2 tile_dims_B_last =
          transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

      loader_a.set_mask(tile_dims_A_last, mask_A);
      loader_b.set_mask(tile_dims_B_last, mask_B);

      loader_a.load_safe(mask_A);
      loader_b.load_safe(mask_B);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      mma_op.mma(As, Bs);
    }
  }

  /* Main kernel function */
  static METAL_FUNC void run(
      const device T* A [[buffer(0)]],
      const device T* B [[buffer(1)]],
      device U* C [[buffer(2)]],
      const constant GEMMParams* params [[buffer(3)]],
      threadgroup T* As [[threadgroup(0)]],
      threadgroup T* Bs [[threadgroup(1)]],
      uint simd_lane_id [[thread_index_in_simdgroup]],
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint3 tid [[threadgroup_position_in_grid]],
      uint3 lid [[thread_position_in_threadgroup]]) {
    // Pacifying compiler
    (void)lid;

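    // Swizzle the threadgroup grid so that neighboring threadgroups work on
    // nearby tiles of C, improving cache locality.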
    const int tid_y = ((tid.y) << params->swizzle_log) +
        ((tid.x) & ((1 << params->swizzle_log) - 1));
    const int tid_x = (tid.x) >> params->swizzle_log;

    if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
      return;
    }

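    // Execution barrier only; no threadgroup memory to synchronize yet.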
    threadgroup_barrier(mem_flags::mem_none);

    // Find block in A, B, C
    const int c_row = tid_y * BM;
    const int c_col = tid_x * BN;

    A += transpose_a ? c_row : c_row * params->lda;
    B += transpose_b ? c_col * params->ldb : c_col;
    C += c_row * params->ldc + c_col;

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    int gemm_k_iterations = params->gemm_k_iterations_aligned;

    ///////////////////////////////////////////////////////////////////////////
    // MNK aligned loop
    if (MN_aligned) {
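      // Every tile is full-sized, so all loads can skip masking.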
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      threadgroup_barrier(mem_flags::mem_none);

      // Loop tail
      if (!K_aligned) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size];
        thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size];

        loader_a.set_mask(tile_dims_A, mask_A);
        loader_b.set_mask(tile_dims_B, mask_B);

        loader_a.load_safe(mask_A);
        loader_b.load_safe(mask_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }

      // Store results to device memory
      mma_op.store_result(C, params->ldc);
      return;
    }
    ///////////////////////////////////////////////////////////////////////////
    // MN unaligned loop
    else { // Loop over K - unaligned case
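      // Actual (possibly partial) tile extents at the bottom/right edges of C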
      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);
      short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;

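      // Dispatch to the gemm_loop specialization matching which edges of this
      // threadgroup's tile are full-sized; aligned edges skip load masking.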
      if (tgp_bm == BM && tgp_bn == BN) {
        gemm_loop<true, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result(C, params->ldc);
        return;
      } else if (tgp_bn == BN) {
        gemm_loop<false, true, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
        return;
      } else if (tgp_bm == BM) {
        gemm_loop<true, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
        return;
      } else {
        gemm_loop<false, false, K_aligned>(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            leftover_bk);

        mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
        return;
      }
    }
  }
};

} // namespace steel
} // namespace mlx
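
GEMMKernel is only a struct; the [[kernel]] entry points that instantiate it live elsewhere and are not part of this file. For orientation, below is a minimal sketch of how such an entry point could allocate the threadgroup tiles and forward to run(). The kernel name, the buffer bindings, and the single shared allocation split via tgp_mem_size_a are assumptions for illustration, not code from this commit.

// Hypothetical instantiation sketch (not part of this commit).
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel]] void gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* C [[buffer(2)]],
    const constant mlx::steel::GEMMParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  using gemm_kernel = mlx::steel::GEMMKernel<
      T, T, BM, BN, BK, WM, WN,
      transpose_a, transpose_b, MN_aligned, K_aligned>;

  // One threadgroup allocation, split into the A and B tiles.
  threadgroup T tgp_memory[gemm_kernel::tgp_mem_size];

  gemm_kernel::run(
      A, B, C, params,
      tgp_memory,                                // As
      tgp_memory + gemm_kernel::tgp_mem_size_a,  // Bs
      simd_lane_id, simd_group_id, tid, lid);
}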