@@ -77,6 +77,47 @@ union MatMulScaleType_t {
   double cf64[2];
 };
 
+template <typename OpA, typename OpB, typename OpC, MatXMatMulProvider_t PROV = PROVIDER_TYPE_CUBLASLT>
+constexpr bool CompatibleGemmTypes() {
+  if constexpr (!std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
+                !std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type> &&
+                !std::is_same_v<typename OpA::scalar_type, typename OpC::scalar_type>) {
+    return false;
+  }
+
+  if constexpr (PROV == PROVIDER_TYPE_CUBLASLT) {
+    if constexpr (std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
+                  std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
+      // List of accepted types when A/B/C match
+      return std::is_same_v<typename OpA::scalar_type, matxFp16> ||
+             std::is_same_v<typename OpA::scalar_type, matxBf16> ||
+             std::is_same_v<typename OpA::scalar_type, float> ||
+             std::is_same_v<typename OpA::scalar_type, double> ||
+             std::is_same_v<typename OpA::scalar_type, cuda::std::complex<float>> ||
+             std::is_same_v<typename OpA::scalar_type, cuda::std::complex<double>> ||
+             std::is_same_v<typename OpA::scalar_type, int8_t> ||
+             std::is_same_v<typename OpA::scalar_type, matxFp16Complex> ||
+             std::is_same_v<typename OpA::scalar_type, matxBf16Complex>;
+    }
+    // Accumulator type different from A/B
+    else if constexpr (std::is_same_v<typename OpA::scalar_type, typename OpB::scalar_type> &&
+                       !std::is_same_v<typename OpB::scalar_type, typename OpC::scalar_type>) {
+      return (std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, int32_t>) ||
+             (std::is_same_v<typename OpA::scalar_type, int8_t> && std::is_same_v<typename OpC::scalar_type, float>) ||
+             (std::is_same_v<typename OpA::scalar_type, matxBf16> && std::is_same_v<typename OpC::scalar_type, float>) ||
+             (std::is_same_v<typename OpA::scalar_type, matxFp16> && std::is_same_v<typename OpC::scalar_type, float>);
+    }
+    else {
+      // Mismatched A/B input types are not supported; also guarantees all
+      // instantiations return a value
+      return false;
+    }
+  }
+  else {
+    // For now return true for other providers until we support more
+    return true;
+  }
+}
+
 
 /**
  * Parameters needed to execute a GEMM. For the most part, these are very
  * similar to that of a standard GEMM call
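For reference, a minimal compile-time sketch of how the new trait behaves under the default cuBLASLt provider. The OpXX structs are hypothetical stand-ins placed next to the trait inside the matx namespace; CompatibleGemmTypes() only inspects the nested scalar_type of each operator type:

    #include <cstdint>

    // Hypothetical stand-in operator types; only scalar_type is examined.
    struct OpF32 { using scalar_type = float; };
    struct OpF64 { using scalar_type = double; };
    struct OpI8  { using scalar_type = int8_t; };
    struct OpI32 { using scalar_type = int32_t; };

    // Matching A/B/C types on the accepted list pass.
    static_assert(detail::CompatibleGemmTypes<OpF32, OpF32, OpF32>());
    // int8 inputs with an int32 accumulator are an accepted mixed combination.
    static_assert(detail::CompatibleGemmTypes<OpI8, OpI8, OpI32>());
    // double inputs with a float accumulator are on neither list.
    static_assert(!detail::CompatibleGemmTypes<OpF64, OpF64, OpF32>());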
@@ -834,7 +873,7 @@ class matxMatMulHandle_t {
             static_cast<int>(
                 params_.ldc)}, // Tensor-ref for destination matrix D (may be
                                // different memory than source C matrix)
-        {alpha, beta});        // Scalars used in the Epilogue
+        {static_cast<T1>(alpha), static_cast<T1>(beta)}); // Scalars used in the Epilogue
 
     CutlassGemm gemm_operator;
     cutlass::Status status = gemm_operator(args, nullptr, stream);
@@ -895,7 +934,7 @@ class matxMatMulHandle_t {
                 params_.ldc)},  // Tensor-ref for destination matrix D (may
                                 // be different memory than source C matrix)
        c_adj.Stride(RANK - 3),  // Batch Stride C
-       {alpha, beta},
+       {static_cast<T1>(alpha), static_cast<T1>(beta)},
        params_.batch            // Batch Dimension
    );                           // Scalars used in the Epilogue
 
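Both epilogue hunks above make the same fix: alpha and beta arrive as wider scalars while the CUTLASS epilogue arguments are declared in terms of the element type T1, so the braced initializer only compiles when the conversion is spelled out. A minimal sketch of one way the plain {alpha, beta} form can fail, using hypothetical stand-in types (whether the real element types have explicit float constructors is an assumption here):

    // Stand-ins: Fp16Like models an element type (e.g. matxFp16) whose
    // constructor from float is explicit; Params models the epilogue argument
    // struct, whose members have the element type T1.
    struct Fp16Like {
      explicit Fp16Like(float) {}
    };
    struct Params {
      Fp16Like alpha, beta;
    };

    void build(float alpha, float beta) {
      // Params p{alpha, beta};  // error: no implicit float -> Fp16Like conversion
      Params p{Fp16Like{alpha}, Fp16Like{beta}}; // OK: explicit, as in the diff
      (void)p;
    }

Even where the conversion happens to be implicit, the explicit cast keeps the initializer unambiguous as T1 varies across instantiations.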
@@ -1118,6 +1157,10 @@ void matmul_impl(TensorTypeC C, const TensorTypeA A,
   auto A_ = as_type<typename TensorTypeC::scalar_type>(A);
   auto B_ = as_type<typename TensorTypeC::scalar_type>(B);
 
+  static_assert(detail::CompatibleGemmTypes<decltype(A_), decltype(B_), TensorTypeC, PROV>(),
+                "Combination of A/B/C types is not supported");
+
+
   // CublasLt does not support operators and certain transpose modes.
   // Grab a supported tensor here and copy in if necessary.
   auto c = getCublasSupportedTensor(C, stream);
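At the call site, the effect is that an unsupported output type now fails with a readable diagnostic instead of an opaque provider error. Because matmul_impl first converts A and B to C's scalar type via as_type, the assertion effectively vets C's scalar type (and the provider) before dispatch. A hedged sketch of the user-facing behavior, assuming the make_tensor factory and the operator-style matmul API (exact spellings vary across MatX versions):

    #include "matx.h"
    using namespace matx;

    int main() {
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      auto A = make_tensor<float>({8, 4});
      auto B = make_tensor<float>({4, 16});
      auto C = make_tensor<float>({8, 16});

      // float A/B/C is on the accepted cuBLASLt list, so this compiles as before.
      (C = matmul(A, B)).run(stream);

      // An output type outside the accepted list is now rejected at compile
      // time by the new static_assert rather than failing inside the provider:
      // auto Cbad = make_tensor<int16_t>({8, 16});
      // (Cbad = matmul(A, B)).run(stream); // "Combination of A/B/C types is not supported"

      cudaStreamSynchronize(stream);
      cudaStreamDestroy(stream);
      return 0;
    }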