-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
1,521 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
// Copyright © 2024 Apple Inc. | ||
|
||
#pragma once | ||
|
||
#include "loader.h" | ||
#include "mma.h" | ||
#include "transforms.h" | ||
#include "../utils.h" | ||
|
||
using namespace metal; | ||
|
||
/////////////////////////////////////////////////////////////////////////////// | ||
// GEMM kernel class | ||
/////////////////////////////////////////////////////////////////////////////// | ||
|
||
namespace mlx { | ||
namespace steel { | ||
|
||
/// Empty compile-time tag carrying three alignment flags (M, N, K).
/// It holds no data; GEMMKernel::gemm_loop accepts it as a defaulted trailing
/// argument whose template arguments mirror gemm_loop's own alignment flags.
template <bool AlignedM, bool AlignedN, bool AlignedK>
struct LoopAlignment {};
|
||
template < | ||
typename T, | ||
typename U, | ||
int BM, | ||
int BN, | ||
int BK, | ||
int WM, | ||
int WN, | ||
bool transpose_a, | ||
bool transpose_b, | ||
bool MN_aligned, | ||
bool K_aligned, | ||
typename AccumType = typename AccumHelper<T>::accum_type, | ||
typename Epilogue = TransformNone<U, AccumType>> | ||
struct GEMMKernel { | ||
STEEL_CONST short tgp_padding_a = 16 / sizeof(T); | ||
STEEL_CONST short tgp_padding_b = 16 / sizeof(T); | ||
STEEL_CONST short tgp_mem_size_a = | ||
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); | ||
STEEL_CONST short tgp_mem_size_b = | ||
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); | ||
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; | ||
|
||
STEEL_CONST short tgp_size = WM * WN * 32; | ||
|
||
using loader_a_t = BlockLoader< | ||
T, | ||
transpose_a ? BK : BM, | ||
transpose_a ? BM : BK, | ||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, | ||
!transpose_a, | ||
tgp_size>; | ||
using loader_b_t = BlockLoader< | ||
T, | ||
transpose_b ? BN : BK, | ||
transpose_b ? BK : BN, | ||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, | ||
transpose_b, | ||
tgp_size>; | ||
using mma_t = BlockMMA< | ||
T, | ||
U, | ||
BM, | ||
BN, | ||
BK, | ||
WM, | ||
WN, | ||
transpose_a, | ||
transpose_b, | ||
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, | ||
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, | ||
AccumType, | ||
Epilogue>; | ||
|
||
/* Main kernel function */ | ||
template <bool M_aligned, bool N_aligned, bool K_aligned_> | ||
static METAL_FUNC void gemm_loop( | ||
threadgroup T* As [[threadgroup(0)]], | ||
threadgroup T* Bs [[threadgroup(1)]], | ||
const int gemm_k_iterations, | ||
thread loader_a_t& loader_a, | ||
thread loader_b_t& loader_b, | ||
thread mma_t& mma_op, | ||
thread const short& tgp_bm, | ||
thread const short& tgp_bn, | ||
thread const short& lbk, | ||
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) { | ||
// Appease the compiler | ||
(void)l; | ||
|
||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; | ||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; | ||
|
||
if (!M_aligned) { | ||
short2 tile_dims_A = | ||
transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); | ||
loader_a.set_mask(tile_dims_A, mask_A); | ||
} | ||
|
||
if (!N_aligned) { | ||
short2 tile_dims_B = | ||
transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); | ||
loader_b.set_mask(tile_dims_B, mask_B); | ||
} | ||
|
||
for (int k = 0; k < gemm_k_iterations; k++) { | ||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
// Load elements into threadgroup | ||
if (M_aligned) { | ||
loader_a.load_unsafe(); | ||
} else { | ||
loader_a.load_safe(mask_A); | ||
} | ||
|
||
if (N_aligned) { | ||
loader_b.load_unsafe(); | ||
} else { | ||
loader_b.load_safe(mask_B); | ||
} | ||
|
||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
// Multiply and accumulate threadgroup elements | ||
mma_op.mma(As, Bs); | ||
|
||
// Prepare for next iteration | ||
loader_a.next(); | ||
loader_b.next(); | ||
} | ||
|
||
if (!K_aligned_) { | ||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
short2 tile_dims_A_last = | ||
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); | ||
short2 tile_dims_B_last = | ||
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); | ||
|
||
loader_a.set_mask(tile_dims_A_last, mask_A); | ||
loader_b.set_mask(tile_dims_B_last, mask_B); | ||
|
||
loader_a.load_safe(mask_A); | ||
loader_b.load_safe(mask_B); | ||
|
||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
mma_op.mma(As, Bs); | ||
} | ||
} | ||
|
||
/* Main kernel function */ | ||
static METAL_FUNC void run( | ||
const device T* A [[buffer(0)]], | ||
const device T* B [[buffer(1)]], | ||
device U* C [[buffer(2)]], | ||
const constant GEMMParams* params [[buffer(3)]], | ||
threadgroup T* As [[threadgroup(0)]], | ||
threadgroup T* Bs [[threadgroup(1)]], | ||
uint simd_lane_id [[thread_index_in_simdgroup]], | ||
uint simd_group_id [[simdgroup_index_in_threadgroup]], | ||
uint3 tid [[threadgroup_position_in_grid]], | ||
uint3 lid [[thread_position_in_threadgroup]]) { | ||
// Pacifying compiler | ||
(void)lid; | ||
|
||
const int tid_y = ((tid.y) << params->swizzle_log) + | ||
((tid.x) & ((1 << params->swizzle_log) - 1)); | ||
const int tid_x = (tid.x) >> params->swizzle_log; | ||
|
||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { | ||
return; | ||
} | ||
|
||
threadgroup_barrier(mem_flags::mem_none); | ||
|
||
// Find block in A, B, C | ||
const int c_row = tid_y * BM; | ||
const int c_col = tid_x * BN; | ||
|
||
A += transpose_a ? c_row : c_row * params->lda; | ||
B += transpose_b ? c_col * params->ldb : c_col; | ||
C += c_row * params->ldc + c_col; | ||
|
||
// Prepare threadgroup loading operations | ||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); | ||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); | ||
|
||
// Prepare threadgroup mma operation | ||
thread mma_t mma_op(simd_group_id, simd_lane_id); | ||
|
||
int gemm_k_iterations = params->gemm_k_iterations_aligned; | ||
|
||
/////////////////////////////////////////////////////////////////////////////// | ||
// MNK aligned loop | ||
if (MN_aligned) { | ||
for (int k = 0; k < gemm_k_iterations; k++) { | ||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
// Load elements into threadgroup | ||
loader_a.load_unsafe(); | ||
loader_b.load_unsafe(); | ||
|
||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
// Multiply and accumulate threadgroup elements | ||
mma_op.mma(As, Bs); | ||
|
||
// Prepare for next iteration | ||
loader_a.next(); | ||
loader_b.next(); | ||
} | ||
|
||
threadgroup_barrier(mem_flags::mem_none); | ||
|
||
// Loop tail | ||
if (!K_aligned) { | ||
int lbk = params->K - params->gemm_k_iterations_aligned * BK; | ||
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); | ||
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); | ||
|
||
thread bool mask_A[loader_a_t::n_rows][loader_a_t::vec_size]; | ||
thread bool mask_B[loader_b_t::n_rows][loader_b_t::vec_size]; | ||
|
||
loader_a.set_mask(tile_dims_A, mask_A); | ||
loader_b.set_mask(tile_dims_B, mask_B); | ||
|
||
loader_a.load_safe(mask_A); | ||
loader_b.load_safe(mask_B); | ||
|
||
threadgroup_barrier(mem_flags::mem_threadgroup); | ||
|
||
mma_op.mma(As, Bs); | ||
} | ||
|
||
// Store results to device memory | ||
mma_op.store_result(C, params->ldc); | ||
return; | ||
|
||
} | ||
/////////////////////////////////////////////////////////////////////////////// | ||
// MN unaligned loop | ||
else { // Loop over K - unaligned case | ||
short tgp_bm = min(BM, params->M - c_row); | ||
short tgp_bn = min(BN, params->N - c_col); | ||
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; | ||
|
||
if (tgp_bm == BM && tgp_bn == BN) { | ||
gemm_loop<true, true, K_aligned>( | ||
As, | ||
Bs, | ||
gemm_k_iterations, | ||
loader_a, | ||
loader_b, | ||
mma_op, | ||
tgp_bm, | ||
tgp_bn, | ||
leftover_bk); | ||
|
||
mma_op.store_result(C, params->ldc); | ||
return; | ||
|
||
} else if (tgp_bn == BN) { | ||
gemm_loop<false, true, K_aligned>( | ||
As, | ||
Bs, | ||
gemm_k_iterations, | ||
loader_a, | ||
loader_b, | ||
mma_op, | ||
tgp_bm, | ||
tgp_bn, | ||
leftover_bk); | ||
|
||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); | ||
return; | ||
|
||
} else if (tgp_bm == BM) { | ||
gemm_loop<true, false, K_aligned>( | ||
As, | ||
Bs, | ||
gemm_k_iterations, | ||
loader_a, | ||
loader_b, | ||
mma_op, | ||
tgp_bm, | ||
tgp_bn, | ||
leftover_bk); | ||
|
||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); | ||
return; | ||
|
||
} else { | ||
gemm_loop<false, false, K_aligned>( | ||
As, | ||
Bs, | ||
gemm_k_iterations, | ||
loader_a, | ||
loader_b, | ||
mma_op, | ||
tgp_bm, | ||
tgp_bn, | ||
leftover_bk); | ||
|
||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); | ||
return; | ||
} | ||
} | ||
} | ||
}; | ||
|
||
} // namespace steel | ||
} // namespace mlx |
Oops, something went wrong.