From 5728976f45d04edceeb6ee4394d9ae88e2fb2dd9 Mon Sep 17 00:00:00 2001
From: Liu Liu
Date: Thu, 15 Aug 2024 18:23:39 -0400
Subject: [PATCH] Move some code around.

---
 lib/nnc/mfa/v2/GEMMKernel.cpp | 162 ++++++++++++++++++----------------
 1 file changed, 84 insertions(+), 78 deletions(-)

diff --git a/lib/nnc/mfa/v2/GEMMKernel.cpp b/lib/nnc/mfa/v2/GEMMKernel.cpp
index 637d51e9e..96cf87a33 100644
--- a/lib/nnc/mfa/v2/GEMMKernel.cpp
+++ b/lib/nnc/mfa/v2/GEMMKernel.cpp
@@ -229,6 +229,8 @@ GEMMKernel::GEMMKernel(GEMMKernelDescriptor descriptor, MTL::Device *const devic
   }
 }
 
+#pragma mark - Source
+
 std::string GEMMKernel::createSource() const noexcept {
   CodeWriter source;
 
@@ -496,6 +498,8 @@ for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
   *source += createMultiply;
 }
 
+#pragma mark - Caching
+
 void GEMMKernel::createInitializeC(CodeWriter *source) const noexcept {
   source->SetValue("REGISTER_M_8_REGISTER_N_8", std::to_string((registerM / 8) * (registerN / 8)));
   *source += R"(
@@ -733,6 +737,86 @@ if ({{DIRECT_ACCESS_CONDITION}}) {
 )";
 }
 
+void GEMMKernel::createStoreC(CodeWriter *source) const noexcept {
+  if (memoryPrecisions.C == GEMMOperandPrecision::BF16 && registerPrecisions.C == GEMMOperandPrecision::FP32) {
+    source->SetValue("STORE_FUNCTION_C", "store_bfloat");
+  } else {
+    source->SetValue("STORE_FUNCTION_C", "store");
+  }
+
+  *source += R"(
+
+if ({{DIRECT_ACCESS_CONDITION}}) {
+  // Fast path for matrices that qualify.
+  uint2 C_offset(N_offset + offset_in_group.x,
+                 M_offset + offset_in_group.y);
+  auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
+    C, {{LEADING_DIMENSION_C}}, C_offset);
+
+  // Write the accumulator to device memory.
+#pragma clang loop unroll(full)
+  for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
+#pragma clang loop unroll(full)
+    for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
+      ushort2 origin(n, m);
+      auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
+      C->{{STORE_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin);
+    }
+  }
+} else {
+  // Slow path for when memory must be handled more carefully.
+  auto C_block = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block);
+  auto C_block_dst =
+  simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
+    C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, offset_in_group);
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Write the accumulator to threadgroup memory.
+#pragma clang loop unroll(full)
+  for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
+#pragma clang loop unroll(full)
+    for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
+      ushort2 origin(n, m);
+      auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
+      C->{{STORE_FUNCTION_C}}(
+        C_block_dst, {{LEADING_BLOCK_DIMENSIONS_C}}, origin);
+    }
+  }
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  // Launch the async copy from threadgroup to device memory.
+  if (sidx == 0) {
+    uint2 C_offset(gid.x * N_group, gid.y * M_group);
+    ushort2 C_tile(min(uint(N_group), N - C_offset.x),
+                   min(uint(M_group), M - C_offset.y));
+    auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
+      C, {{LEADING_DIMENSION_C}}, C_offset);
+
+    // If we shift successfully, the garbage zone moves from the bottom right
+    // to the top left.
+    if ((M_shift != 0) || (N_shift != 0)) {
+      ushort2 C_block_shift(0, 0);
+      if ((M_shift != 0) && (C_offset.y >= M_edge)) {
+        C_block_shift.y = M_shift;
+      }
+      if ((N_shift != 0) && (C_offset.x >= N_edge)) {
+        C_block_shift.x = N_shift;
+      }
+      C_block = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
+        C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_block_shift);
+    }
+
+    simdgroup_event event;
+    event.async_copy(
+      C_dst, {{LEADING_DIMENSION_C}}, C_tile,
+      C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_tile);
+  }
+}
+)";
+}
+
+#pragma mark - Multiply
+
 void GEMMKernel::createMultiplyIterations(CodeWriter *source) const noexcept {
   if (preferAsyncLoad) {
     source->SetValue("ASYNC_ITERATIONS_START", "0");
@@ -840,81 +924,3 @@ for (uint k = {{ASYNC_ITERATIONS_START}}; k < K; k += K_group) {
 
 )";
 }
-
-void GEMMKernel::createStoreC(CodeWriter *source) const noexcept {
-  if (memoryPrecisions.C == GEMMOperandPrecision::BF16 && registerPrecisions.C == GEMMOperandPrecision::FP32) {
-    source->SetValue("STORE_FUNCTION_C", "store_bfloat");
-  } else {
-    source->SetValue("STORE_FUNCTION_C", "store");
-  }
-
-  *source += R"(
-
-if ({{DIRECT_ACCESS_CONDITION}}) {
-  // Fast path for matrices that qualify.
-  uint2 C_offset(N_offset + offset_in_group.x,
-                 M_offset + offset_in_group.y);
-  auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
-    C, {{LEADING_DIMENSION_C}}, C_offset);
-
-  // Write the accumulator to device memory.
-#pragma clang loop unroll(full)
-  for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
-#pragma clang loop unroll(full)
-    for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
-      ushort2 origin(n, m);
-      auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
-      C->{{STORE_FUNCTION_C}}(C_dst, {{LEADING_DIMENSION_C}}, origin);
-    }
-  }
-} else {
-  // Slow path for when memory must be handled more carefully.
-  auto C_block = (threadgroup {{MEMORY_NAME_C}}*)(threadgroup_block);
-  auto C_block_dst =
-  simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
-    C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, offset_in_group);
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Write the accumulator to threadgroup memory.
-#pragma clang loop unroll(full)
-  for (ushort m = 0; m < {{REGISTER_M}}; m += 8) {
-#pragma clang loop unroll(full)
-    for (ushort n = 0; n < {{REGISTER_N}}; n += 8) {
-      ushort2 origin(n, m);
-      auto C = get_sram(C_sram, {{REGISTER_N}}, origin);
-      C->{{STORE_FUNCTION_C}}(
-        C_block_dst, {{LEADING_BLOCK_DIMENSIONS_C}}, origin);
-    }
-  }
-  threadgroup_barrier(mem_flags::mem_threadgroup);
-
-  // Launch the async copy from threadgroup to device memory.
-  if (sidx == 0) {
-    uint2 C_offset(gid.x * N_group, gid.y * M_group);
-    ushort2 C_tile(min(uint(N_group), N - C_offset.x),
-                   min(uint(M_group), M - C_offset.y));
-    auto C_dst = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
-      C, {{LEADING_DIMENSION_C}}, C_offset);
-
-    // If we shift successfully, the garbage zone moves from the bottom right
-    // to the top left.
-    if ((M_shift != 0) || (N_shift != 0)) {
-      ushort2 C_block_shift(0, 0);
-      if ((M_shift != 0) && (C_offset.y >= M_edge)) {
-        C_block_shift.y = M_shift;
-      }
-      if ((N_shift != 0) && (C_offset.x >= N_edge)) {
-        C_block_shift.x = N_shift;
-      }
-      C_block = simdgroup_matrix_storage<{{MEMORY_NAME_C}}>::apply_offset(
-        C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_block_shift);
-    }
-
-    simdgroup_event event;
-    event.async_copy(
-      C_dst, {{LEADING_DIMENSION_C}}, C_tile,
-      C_block, {{LEADING_BLOCK_DIMENSIONS_C}}, C_tile);
-  }
-}
-)";
-}