
Commit bfe279e

Check matmul types and error at compile-time if the backend doesn't support them
1 parent 6951f04 commit bfe279e

File tree

1 file changed: +45 −2 lines


include/matx/transforms/matmul.h (+45 −2)
@@ -77,6 +77,45 @@ union MatMulScaleType_t {
   double cf64[2];
 };
 
+template <typename OpA, typename OpB, typename OpC, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
+constexpr bool CompatibleGemmTypes() {
+  if constexpr (!std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
+                !std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type> &&
+                !std::is_same_v<typename OpA::scalar_type, typename OpC::scalar_type>) {
+    return false;
+  }
+
+  if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
+    if constexpr (std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
+                  std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
+      // List of accepted types when A/B/C match
+      return std::is_same_v<typename OpA::scalar_type, matxFp16> ||
+             std::is_same_v<typename OpA::scalar_type, matxBf16> ||
+             std::is_same_v<typename OpA::scalar_type, float> ||
+             std::is_same_v<typename OpA::scalar_type, double> ||
+             std::is_same_v<typename OpA::scalar_type, cuda::std::complex<float>> ||
+             std::is_same_v<typename OpA::scalar_type, cuda::std::complex<double>> ||
+             std::is_same_v<typename OpA::scalar_type, int8_t> ||
+             std::is_same_v<typename OpA::scalar_type, matxFp16Complex> ||
+             std::is_same_v<typename OpA::scalar_type, matxBf16Complex>;
+
+    }
+    // Accumulator type different from A/B
+    else if constexpr ( std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
+                       !std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
+      return (std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, int32_t>) ||
+             (std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, float>) ||
+             (std::is_same_v<typename OpA::scalar_type, matxBf16> && std::is_same_v<typename OpC::scalar_type, float>) ||
+             (std::is_same_v<typename OpA::scalar_type, matxFp16> && std::is_same_v<typename OpC::scalar_type, float>) ||
+             (std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, float>);
+    }
+  }
+  else {
+    // For now return true for other providers until we support more
+    return true;
+  }
+}
+
 /**
  * Parameters needed to execute a GEMM. For the most part, these are very
  * similar to that of a standard GEMM call
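The hunk above is a purely compile-time predicate: every branch is an if constexpr over std::is_same_v, so the result is a constant expression that can feed a static_assert (as the last hunk below does). The sketch that follows is a simplified, hypothetical analogue of that pattern which compiles outside of MatX; FakeOp and compatible_gemm_types are stand-in names invented here, and the accepted-type lists are trimmed to standard types so the snippet stays self-contained.

#include <complex>
#include <cstdint>
#include <type_traits>

// Minimal stand-in for a MatX operator: the predicate only needs scalar_type.
template <typename T> struct FakeOp { using scalar_type = T; };

template <typename OpA, typename OpB, typename OpC>
constexpr bool compatible_gemm_types() {
  using A = typename OpA::scalar_type;
  using B = typename OpB::scalar_type;
  using C = typename OpC::scalar_type;

  if constexpr (std::is_same_v<A, B> && std::is_same_v<B, C>) {
    // Uniform A/B/C type: accept a short list of real and complex types.
    return std::is_same_v<A, float> || std::is_same_v<A, double> ||
           std::is_same_v<A, std::complex<float>> ||
           std::is_same_v<A, std::complex<double>>;
  }
  else if constexpr (std::is_same_v<A, B> && !std::is_same_v<B, C>) {
    // Mixed precision: int8 inputs accumulating into int32 or float.
    return std::is_same_v<A, int8_t> &&
           (std::is_same_v<C, int32_t> || std::is_same_v<C, float>);
  }
  else {
    return false; // A and B differ: not a supported configuration here
  }
}

// Because the predicate is a constant expression, a bad combination is
// rejected while the translation unit compiles, not when the GEMM runs.
static_assert( compatible_gemm_types<FakeOp<float>,  FakeOp<float>,  FakeOp<float>>());
static_assert( compatible_gemm_types<FakeOp<int8_t>, FakeOp<int8_t>, FakeOp<int32_t>>());
static_assert(!compatible_gemm_types<FakeOp<float>,  FakeOp<double>, FakeOp<float>>());

int main() { return 0; }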
@@ -834,7 +873,7 @@ class matxMatMulHandle_t {
             static_cast<int>(
                 params_.ldc)}, // Tensor-ref for destination matrix D (may be
                                // different memory than source C matrix)
-            {alpha, beta}); // Scalars used in the Epilogue
+            {static_cast<T1>(alpha), static_cast<T1>(beta)}); // Scalars used in the Epilogue
 
         CutlassGemm gemm_operator;
         cutlass::Status status = gemm_operator(args, nullptr, stream);
@@ -895,7 +934,7 @@ class matxMatMulHandle_t {
                 params_.ldc)}, // Tensor-ref for destination matrix D (may
                                // be different memory than source C matrix)
             c_adj.Stride(RANK - 3), // Batch Stride C
-            {alpha, beta},
+            {static_cast<T1>(alpha), static_cast<T1>(beta)},
             params_.batch // Batch Dimension
         ); // Scalars used in the Epilogue
 
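This hunk and the previous one make the same one-line change: alpha and beta are wrapped in static_cast<T1>(...) before they reach the CUTLASS epilogue arguments. A plausible reason, sketched below with hypothetical stand-in types rather than CUTLASS's own, is that brace-initialization rejects narrowing conversions: if the scalars are held in a wider type than the epilogue element type T1, {alpha, beta} fails to compile, while the explicit casts state the conversion and are accepted.

// Hypothetical stand-in for an epilogue parameter pack (not CUTLASS's type).
template <typename T1>
struct EpilogueParams {
  T1 alpha;
  T1 beta;
};

template <typename T1, typename TScale>
EpilogueParams<T1> make_epilogue(TScale alpha, TScale beta) {
  // return {alpha, beta};  // ill-formed if TScale -> T1 narrows (e.g. double -> float)
  return {static_cast<T1>(alpha), static_cast<T1>(beta)};  // explicit conversion compiles
}

int main() {
  // double scalars flowing into a float epilogue, mirroring the casts in the diff
  auto p = make_epilogue<float>(1.0, 0.0);
  return (p.alpha == 1.0f && p.beta == 0.0f) ? 0 : 1;
}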

@@ -1118,6 +1157,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
   auto A_ = as_type<typename TensorTypeC::scalar_type>(A);
   auto B_ = as_type<typename TensorTypeC::scalar_type>(B);
 
+  static_assert(detail::CompatibleGemmTypes<decltype(A_), decltype(B_), TensorTypeC, PROV>(),
+                "Combination of A/B/C types are not supported");
+
+
   // CublasLt does not support operators and certain transpose modes.
   // Grab a suppported tensor here and copy in if necessary.
   auto c = getCublasSupportedTensor(C, stream);
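Note that matmul_impl first converts A and B to C's scalar type via as_type, so in the uniform-type branch the new check effectively validates the output tensor's scalar type. At the call site the effect is that an unsupported combination is reported by the compiler with "Combination of A/B/C types are not supported" instead of surfacing as a backend failure at run time. A rough sketch of that follows, assuming the operator-style MatX API ((C = matmul(A, B)).run(stream)) and that an int output tensor is otherwise constructible; the unsupported call is left commented out so the snippet still builds.

#include <matx.h>   // assumes a MatX build providing make_tensor() and matmul()

int main() {
  using namespace matx;
  cudaStream_t stream = 0;

  auto A = make_tensor<float>({64, 32});
  auto B = make_tensor<float>({32, 16});

  // Supported: a float GEMM accumulating into float compiles and runs as before.
  auto C = make_tensor<float>({64, 16});
  (C = matmul(A, B)).run(stream);

  // Unsupported (hypothetical example): an int output is not on the cuBLASLt
  // list above, so with this commit the lines below would stop compilation with
  //   "Combination of A/B/C types are not supported"
  // rather than failing inside the backend at run time.
  // auto Cbad = make_tensor<int>({64, 16});
  // (Cbad = matmul(A, B)).run(stream);

  cudaStreamSynchronize(stream);
  return 0;
}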
