Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Extended Gemm interface to support mixed precision operations (#500)
Browse files Browse the repository at this point in the history
Co-authored-by: pgorlani <92453485+pgorlani@users.noreply.github.com>
Co-authored-by: HJA Bird <hja.bird@gmail.com>
Co-authored-by: nscipione <nicolo.scipione@codeplay.com>
  • Loading branch information
4 people authored May 15, 2024
1 parent 5b80c99 commit 3a3113a
Show file tree
Hide file tree
Showing 21 changed files with 539 additions and 390 deletions.
233 changes: 126 additions & 107 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -152,32 +152,42 @@ function(generate_blas_objects blas_level func)
list(APPEND data_list_c "half")
endif()
endif()
foreach(data ${data_list_c})
cpp_type(cpp_data ${data})
foreach(index ${index_list})
foreach(increment ${index_list})
sanitize_file_name(file_name
"${func}_${data}_${index}_${data}_${increment}.cpp")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data}
${index}
${increment}
${file_name}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND FUNC_SRC "${LOCATION}/${file_name}")
endforeach(increment)
endforeach(index)
endforeach(data)
foreach(data_in ${data_list_c})
set(data_list_out ${data_in})
# When using half with Gemm target, generate a mixed-precision
# Gemm kernel (half-float) alongside the fully half based kernel.
if((data_in STREQUAL "half") AND (${func} STREQUAL "gemm"))
list(APPEND data_list_out "float")
endif()
cpp_type(cpp_data_in ${data_in})
foreach(data_out ${data_list_out})
cpp_type(cpp_data_out ${data_out})
foreach(index ${index_list})
foreach(increment ${index_list})
sanitize_file_name(file_name
"${func}_${data_in}_${index}_${data_out}_${increment}.cpp")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data_in}
${cpp_data_out}
${index}
${increment}
${file_name}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_ops.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND FUNC_SRC "${LOCATION}/${file_name}")
endforeach(increment)
endforeach(index)
endforeach(data_out ${data_list_out})
endforeach(data_in)
add_library(${func} OBJECT ${FUNC_SRC})
set_target_compile_def(${func})
target_include_directories(${func} PRIVATE ${PORTBLAS_SRC} ${PORTBLAS_INCLUDE}
Expand Down Expand Up @@ -312,87 +322,97 @@ function(add_gemm_configuration
if(const_pos)
string(REPLACE "_const" "" actualfunc ${func})
endif()
# When using half data type, generate a mixed-precision Gemm
# configuration (half-float) alongside the fully half based one.
set(data_list_out ${data})
if(data STREQUAL "half")
list(APPEND data_list_out "float")
endif()
cpp_type(cpp_data ${data})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "half") AND (symm_a OR symm_b))
continue()
endif()
if (symm_a AND symm_b)
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
if ((symm_a AND trans_b) OR (symm_b AND trans_a))
continue()
endif()
foreach(is_beta_zero ${boolean_list})
foreach(index ${index_list})
set(file_name "${func}_${double_buffer}_${conflict_a}_"
"${conflict_b}_${trans_a}_${trans_b}_"
"${is_beta_zero}_${gemm_memory_type}_"
"${gemm_shape_type}_${gemm_vectorize_type}_"
"${vector_size}_${batch_type}_${use_joint_matrix}_"
"${data}_${index}_${tir}_${tic}_${twr}_"
"${twc}_${tsr}_${tsc}_${tlr}_${tlc}_"
"${item_batch}_${wg_batch}_${symm_a}_${symm_b}_"
"${jm_m}_${jm_n}_${jm_k}_${jm_in_type}_${jm_out_type}_"
"${wg_size}_${cache_line_size}_${data}.cpp")
sanitize_file_name(file_name "${file_name}")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data}
${index}
${double_buffer}
${conflict_a}
${conflict_b}
${trans_a}
${trans_b}
${is_beta_zero}
${gemm_memory_type}
${gemm_shape_type}
${tir}
${tic}
${twr}
${twc}
${tsr}
${tsc}
${tlr}
${tlc}
${item_batch}
${wg_batch}
${jm_m}
${jm_n}
${jm_k}
${jm_in_type}
${jm_out_type}
${wg_size}
${cache_line_size}
${file_name}
${gemm_vectorize_type}
${vector_size}
${batch_type}
${use_joint_matrix}
${symm_a}
${symm_b}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND gemm_sources "${LOCATION}/${file_name}")
set(gemm_sources "${gemm_sources}" PARENT_SCOPE)
endforeach(index)
endforeach(is_beta_zero)
endforeach(trans_b)
endforeach(trans_a)
endforeach(symm_b)
endforeach(symm_a)
foreach(data_out ${data_list_out})
cpp_type(cpp_data_out ${data_out})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "half") AND (symm_a OR symm_b))
continue()
endif()
if (symm_a AND symm_b)
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
if ((symm_a AND trans_b) OR (symm_b AND trans_a))
continue()
endif()
foreach(is_beta_zero ${boolean_list})
foreach(index ${index_list})
set(file_name "${func}_${double_buffer}_${conflict_a}_"
"${conflict_b}_${trans_a}_${trans_b}_"
"${is_beta_zero}_${gemm_memory_type}_"
"${gemm_shape_type}_${gemm_vectorize_type}_"
"${vector_size}_${batch_type}_${use_joint_matrix}_"
"${index}_${tir}_${tic}_${twr}_"
"${twc}_${tsr}_${tsc}_${tlr}_${tlc}_"
"${item_batch}_${wg_batch}_${symm_a}_${symm_b}_"
"${jm_m}_${jm_n}_${jm_k}_${jm_in_type}_${jm_out_type}_"
"${wg_size}_${cache_line_size}_${data}_${data_out}.cpp")
sanitize_file_name(file_name "${file_name}")
add_custom_command(OUTPUT "${LOCATION}/${file_name}"
COMMAND ${PYTHON_EXECUTABLE} ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
${PROJECT_SOURCE_DIR}/external/
${PORTBLAS_SRC_GENERATOR}/gen
${blas_level}
${func}
${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
${cpp_data}
${index}
${double_buffer}
${conflict_a}
${conflict_b}
${trans_a}
${trans_b}
${is_beta_zero}
${gemm_memory_type}
${gemm_shape_type}
${tir}
${tic}
${twr}
${twc}
${tsr}
${tsc}
${tlr}
${tlc}
${item_batch}
${wg_batch}
${jm_m}
${jm_n}
${jm_k}
${jm_in_type}
${jm_out_type}
${wg_size}
${cache_line_size}
${file_name}
${gemm_vectorize_type}
${vector_size}
${batch_type}
${use_joint_matrix}
${symm_a}
${symm_b}
${cpp_data_out}
MAIN_DEPENDENCY ${PORTBLAS_SRC}/interface/${blas_level}/${func}.cpp.in
DEPENDS ${PORTBLAS_SRC_GENERATOR}/py_gen_blas_gemm_launcher.py
WORKING_DIRECTORY ${PROJECT_BINARY_DIR}
VERBATIM
)
list(APPEND gemm_sources "${LOCATION}/${file_name}")
set(gemm_sources "${gemm_sources}" PARENT_SCOPE)
endforeach(index)
endforeach(is_beta_zero)
endforeach(trans_b)
endforeach(trans_a)
endforeach(symm_b)
endforeach(symm_a)
endforeach(data_out)
endfunction()
if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
set(supported_types
Expand Down Expand Up @@ -702,7 +722,6 @@ else() # default cpu backend
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")

if(BLAS_ENABLE_HALF)
add_gemm_configuration(
"half" 128 "false" "false" "false"
Expand Down
4 changes: 2 additions & 2 deletions include/operations/blas3_trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ struct Tile {
* @tparam TransB iff true, matrix B will be transposed on the fly
* @tparam SymmA whether the matrix A is a symmetric triangular matrix
* @tparam SymmB whether the matrix B is a symmetric triangular matrix
* @tparam element_t type of matrix elements
* @tparam element_t type of scalar alpha & beta
* @tparam UseJointMatrix boolean parameter to decide whether to use
* joint_matrix or not
* @param a_ the lhs_t matrix
Expand All @@ -195,7 +195,7 @@ template <typename input_t, typename output_t, bool DoubleBuffer, bool NbcA,
int VectorSize, int BatchType, bool UseJointMatrix = false>
class Gemm {
public:
using value_t = element_t;
using value_t = typename input_t::value_t;
using index_t = typename std::make_signed<typename input_t::index_t>::type;
static constexpr int wg_size = tile_type::wg_rows * tile_type::wg_cols;
static constexpr bool trans_a = TransA;
Expand Down
12 changes: 9 additions & 3 deletions python_generator/py_gen_blas_gemm_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
blas_level_name = sys.argv[3]
blas_function_name = sys.argv[4]
blas_template_impl = sys.argv[5]
data = sys.argv[6]
data_in = sys.argv[6]
index = sys.argv[7]
double_buffer = sys.argv[8]
conflict_a = sys.argv[9]
Expand Down Expand Up @@ -72,6 +72,7 @@
use_joint_matrix = sys.argv[37]
symm_a = sys.argv[38]
symm_b = sys.argv[39]
data_out = sys.argv[40] # Different from data_in for mixed-precision cases
source = 'generated_src/' + blas_level_name + '/' + blas_function_name + '/'
try:
os.makedirs(source)
Expand Down Expand Up @@ -208,8 +209,13 @@
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
key='DATA_TYPE',
vals=[data],
key='DATA_TYPE_IN',
vals=[data_in],
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
key='DATA_TYPE_OUT',
vals=[data_out],
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
Expand Down
26 changes: 21 additions & 5 deletions python_generator/py_gen_blas_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@
blas_level_name = sys.argv[3]
blas_function_name = sys.argv[4]
blas_template_impl = sys.argv[5]
data = sys.argv[6]
index = sys.argv[7]
increment = sys.argv[8]
file_name = sys.argv[9]
data_in = sys.argv[6]
data_out = sys.argv[7]
index = sys.argv[8]
increment = sys.argv[9]
file_name = sys.argv[10]
source = 'generated_src/' + blas_level_name + '/' + blas_function_name + '/'

try:
Expand All @@ -58,7 +59,7 @@
iterables = [
Iterable(
key='DATA_TYPE',
vals=[data],
vals=[data_in],
itermode=Itermode.combinations,
iter_modifier=1),
Iterable(
Expand All @@ -72,6 +73,21 @@
itermode=Itermode.combinations,
iter_modifier=1)
]

# Gemm supports mixed-precision inputs/outputs/arithmetics
is_gemm: bool = blas_function_name == "gemm"
if is_gemm:
iterables.append(Iterable(
key='DATA_TYPE_IN',
vals=[data_in],
itermode=Itermode.combinations,
iter_modifier=1))
iterables.append(Iterable(
key='DATA_TYPE_OUT',
vals=[data_out],
itermode=Itermode.combinations,
iter_modifier=1))

iter_groups = [IterGroup('@ip1@', template, iterables, combine_iters=True)]
generate_file(
input_template,
Expand Down
6 changes: 4 additions & 2 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
// Unused configuration cases
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
Expand All @@ -53,7 +54,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
const auto n_elem_access = (_M * _K + _K * _N + _M * _N);
const auto arith_ratio = n_fma / n_elem_access;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
static constexpr int tileWgSize = ClSize / sizeof(element_in_t);
if (batch_type == gemm_batch_type_t::interleaved) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
Expand Down Expand Up @@ -242,8 +243,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
using element_in_t = typename ValueType<container_0_t>::type;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
static constexpr int tileWgSize = ClSize / sizeof(element_in_t);
/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) {
Expand Down
Loading

0 comments on commit 3a3113a

Please sign in to comment.