Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SpMM tests #875

Merged
merged 1 commit into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions include/matx/transforms/convert/dense2sparse_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,19 +258,22 @@ void dense2sparse_impl(OutputTensorType &o, const InputTensorType &a,
using TA = typename InputTensorType::value_type;
using TO = typename OutputTensorType::value_type;

static constexpr int RANKA = InputTensorType::Rank();
static constexpr int RANKO = OutputTensorType::Rank();

// Restrictions.
static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
"tensors must have same rank");
static_assert(RANKA == RANKO, "tensors must have same rank");
static_assert(std::is_same_v<TA, TO>,
"tensors must have the same data type");
static_assert(std::is_same_v<TA, int8_t> ||
std::is_same_v<TA, matx::matxFp16> ||
std::is_same_v<TA, matx::matxBf16> ||
std::is_same_v<TA, float> ||
std::is_same_v<TA, double> ||
std::is_same_v<TA, cuda::std::complex<float>> ||
std::is_same_v<TA, cuda::std::complex<double>>,
static_assert(std::is_same_v<TO, int8_t> ||
std::is_same_v<TO, matx::matxFp16> ||
std::is_same_v<TO, matx::matxBf16> ||
std::is_same_v<TO, float> ||
std::is_same_v<TO, double> ||
std::is_same_v<TO, cuda::std::complex<float>> ||
std::is_same_v<TO, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT(a.Stride(RANKA - 1) == 1, matxInvalidParameter);

// Get parameters required by these tensors (for caching).
auto params =
Expand Down
21 changes: 12 additions & 9 deletions include/matx/transforms/convert/sparse2dense_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,22 @@ void sparse2dense_impl(OutputTensorType &o, const InputTensorType &a,
using TA = typename InputTensorType::value_type;
using TO = typename OutputTensorType::value_type;

static constexpr int RANKA = InputTensorType::Rank();
static constexpr int RANKO = OutputTensorType::Rank();

// Restrictions.
static_assert(OutputTensorType::Rank() == InputTensorType::Rank(),
"tensors must have same rank");
static_assert(RANKA == RANKO, "tensors must have same rank");
static_assert(std::is_same_v<TA, TO>,
"tensors must have the same data type");
static_assert(std::is_same_v<TA, int8_t> ||
std::is_same_v<TA, matx::matxFp16> ||
std::is_same_v<TA, matx::matxBf16> ||
std::is_same_v<TA, float> ||
std::is_same_v<TA, double> ||
std::is_same_v<TA, cuda::std::complex<float>> ||
std::is_same_v<TA, cuda::std::complex<double>>,
static_assert(std::is_same_v<TO, int8_t> ||
std::is_same_v<TO, matx::matxFp16> ||
std::is_same_v<TO, matx::matxBf16> ||
std::is_same_v<TO, float> ||
std::is_same_v<TO, double> ||
std::is_same_v<TO, cuda::std::complex<float>> ||
std::is_same_v<TO, cuda::std::complex<double>>,
"unsupported data type");
MATX_ASSERT(o.Stride(RANKO - 1) == 1, matxInvalidParameter);

// Get parameters required by these tensors (for caching).
auto params =
Expand Down
24 changes: 21 additions & 3 deletions include/matx/transforms/matmul/matmul_cusparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,20 @@ class MatMulCUSPARSEHandle_t {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
params_ = GetGemmParams(c, a, b, stream, alpha, beta);

// Properly typed alpha, beta.
if constexpr (std::is_same_v<TC, cuda::std::complex<float>> ||
std::is_same_v<TC, cuda::std::complex<double>>) {
salpha_ = {alpha, 0};
sbeta_ = {beta, 0};
}
else if constexpr (std::is_same_v<TC, float> ||
std::is_same_v<TC, double>) {
salpha_ = alpha;;
sbeta_ = beta;
} else {
MATX_THROW(matxNotSupported, "SpMM currently only supports uniform FP");
}

[[maybe_unused]] cusparseStatus_t ret = cusparseCreate(&handle_);
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);

Expand Down Expand Up @@ -137,7 +151,7 @@ class MatMulCUSPARSEHandle_t {
const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
const cudaDataType comptp = dtc; // TODO: support separate comp type?!
ret = cusparseSpMM_bufferSize(handle_, params_.opA, params_.opB,
&params_.alpha, matA_, matB_, &params_.beta,
&salpha_, matA_, matB_, &sbeta_,
matC_, comptp, algo, &workspaceSize_);
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
if (workspaceSize_) {
Expand Down Expand Up @@ -189,8 +203,8 @@ class MatMulCUSPARSEHandle_t {
const cusparseSpMMAlg_t algo = CUSPARSE_SPMM_ALG_DEFAULT;
const cudaDataType comptp = MatXTypeToCudaType<TC>(); // TODO: see above
[[maybe_unused]] cusparseStatus_t ret =
cusparseSpMM(handle_, params_.opA, params_.opB, &params_.alpha, matA_,
matB_, &params_.beta, matC_, comptp, algo, workspace_);
cusparseSpMM(handle_, params_.opA, params_.opB, &salpha_, matA_,
matB_, &sbeta_, matC_, comptp, algo, workspace_);
MATX_ASSERT(ret == CUSPARSE_STATUS_SUCCESS, matxMatMulError);
}

Expand All @@ -202,6 +216,8 @@ class MatMulCUSPARSEHandle_t {
size_t workspaceSize_ = 0;
void *workspace_ = nullptr;
detail::MatMulCUSPARSEParams_t params_;
TC salpha_;
TC sbeta_;
};

/**
Expand Down Expand Up @@ -272,6 +288,8 @@ void sparse_matmul_impl(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB
a.Size(RANKA - 1) == b.Size(RANKB - 2) &&
c.Size(RANKC - 1) == b.Size(RANKB - 1) &&
c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize);
MATX_ASSERT(b.Stride(RANKB - 1) == 1 &&
c.Stride(RANKC - 1) == 1, matxInvalidParameter);

// Get parameters required by these tensors (for caching).
auto params =
Expand Down
4 changes: 3 additions & 1 deletion include/matx/transforms/solve/solve_cudss.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class SolveCUDSSHandle_t {
MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError);

// Create cuDSS handle for dense matrices B and C.
static_assert(is_tensor_view_v<TensorTypeA>);
static_assert(is_tensor_view_v<TensorTypeB>);
static_assert(is_tensor_view_v<TensorTypeC>);
cudaDataType dtb = MatXTypeToCudaType<TB>();
cudaDataType dtc = MatXTypeToCudaType<TC>();
cudssLayout_t layout = CUDSS_LAYOUT_COL_MAJOR; // no ROW-MAJOR in cuDSS yet
Expand Down Expand Up @@ -250,6 +250,8 @@ void sparse_solve_impl_trans(TensorTypeC &c, const TensorTypeA &a,
a.Size(RANKA - 1) == b.Size(RANKB - 1) &&
a.Size(RANKA - 2) == b.Size(RANKB - 1) &&
b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize);
MATX_ASSERT(b.Stride(RANKB - 1) == 1 &&
c.Stride(RANKC - 1) == 1, matxInvalidParameter);
static_assert(std::is_same_v<typename TensorTypeA::pos_type, int32_t> &&
std::is_same_v<typename TensorTypeA::crd_type, int32_t>, "unsupported index type");

Expand Down
227 changes: 227 additions & 0 deletions test/00_sparse/Matmul.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#include "assert.h"
#include "matx.h"
#include "test_types.h"
#include "utilities.h"
#include "gtest/gtest.h"

using namespace matx;

//
// Helper method to construct:
//
// | 1, 2, 0, 0, 0, 0, 0, 0 |
// | 0, 0, 0, 0, 0, 0, 0, 3 |
// | 0, 0, 0, 0, 0, 0, 4, 0 |
// | 0, 0, 5, 6, 0, 7, 0, 0 |
//
template <typename T> static auto makeA() {
const index_t m = 4;
const index_t n = 8;
tensor_t<T, 2> A = make_tensor<T>({m, n});
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
A(i, j) = static_cast<T>(0);
}
}
A(0, 0) = static_cast<T>(1);
A(0, 1) = static_cast<T>(2);
A(1, 7) = static_cast<T>(3);
A(2, 6) = static_cast<T>(4);
A(3, 2) = static_cast<T>(5);
A(3, 3) = static_cast<T>(6);
A(3, 5) = static_cast<T>(7);
return A;
}

template <typename T> static auto makeB() {
const index_t m = 8;
const index_t n = 2;
tensor_t<T, 2> B = make_tensor<T>({m, n});
B(0, 0) = static_cast<T>(1); B(0, 1) = static_cast<T>(2);
B(1, 0) = static_cast<T>(3); B(1, 1) = static_cast<T>(4);
B(2, 0) = static_cast<T>(5); B(2, 1) = static_cast<T>(6);
B(3, 0) = static_cast<T>(7); B(3, 1) = static_cast<T>(8);
B(4, 0) = static_cast<T>(9); B(4, 1) = static_cast<T>(10);
B(5, 0) = static_cast<T>(11); B(5, 1) = static_cast<T>(12);
B(6, 0) = static_cast<T>(13); B(6, 1) = static_cast<T>(14);
B(7, 0) = static_cast<T>(15); B(7, 1) = static_cast<T>(16);
return B;
}

template <typename T> static auto makeE() {
const index_t m = 4;
const index_t n = 2;
tensor_t<T, 2> E = make_tensor<T>({m, n});
E(0, 0) = static_cast<T>(7); E(0, 1) = static_cast<T>(10);
E(1, 0) = static_cast<T>(45); E(1, 1) = static_cast<T>(48);
E(2, 0) = static_cast<T>(52); E(2, 1) = static_cast<T>(56);
E(3, 0) = static_cast<T>(144); E(3, 1) = static_cast<T>(162);
return E;
}

template <typename T> class MatmulSparseTest : public ::testing::Test {
protected:
float thresh = 0.001f;
};

template <typename T> class MatmulSparseTestsAll : public MatmulSparseTest<T> { };

TYPED_TEST_SUITE(MatmulSparseTestsAll, MatXFloatNonHalfTypesCUDAExec);

TYPED_TEST(MatmulSparseTestsAll, MatmulCOO) {
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

ExecType exec{};

auto A = makeA<TestType>();
auto B = makeB<TestType>();
auto E = makeE<TestType>();
const auto m = A.Size(0);
const auto k = A.Size(1);
const auto n = B.Size(1);

// Convert dense A to sparse S.
auto S = experimental::make_zero_tensor_coo<TestType, index_t>({m, k});
(S = dense2sparse(A)).run(exec);
ASSERT_EQ(S.Nse(), 7);
ASSERT_EQ(S.posSize(1), 0);

// Matmul.
auto O = make_tensor<TestType>({m, n});
(O = matmul(S, B)).run(exec);

// Verify result.
exec.sync();
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
if constexpr (is_complex_v<TestType>) {
ASSERT_NEAR(O(i, j).real(), E(i, j).real(), this->thresh);
ASSERT_NEAR(O(i, j).imag(), E(i,j ).imag(), this->thresh);
}
else {
ASSERT_NEAR(O(i, j), E(i, j), this->thresh);
}

}
}

MATX_EXIT_HANDLER();
}

TYPED_TEST(MatmulSparseTestsAll, MatmulCSR) {
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

ExecType exec{};

auto A = makeA<TestType>();
auto B = makeB<TestType>();
auto E = makeE<TestType>();
const auto m = A.Size(0);
const auto k = A.Size(1);
const auto n = B.Size(1);

// Convert dense A to sparse S.
auto S = experimental::make_zero_tensor_csr<TestType, index_t, index_t>({m, k});
(S = dense2sparse(A)).run(exec);
ASSERT_EQ(S.Nse(), 7);
ASSERT_EQ(S.posSize(1), m + 1);

// Matmul.
auto O = make_tensor<TestType>({m, n});
(O = matmul(S, B)).run(exec);

// Verify result.
exec.sync();
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
if constexpr (is_complex_v<TestType>) {
ASSERT_NEAR(O(i, j).real(), E(i, j).real(), this->thresh);
ASSERT_NEAR(O(i, j).imag(), E(i,j ).imag(), this->thresh);
}
else {
ASSERT_NEAR(O(i, j), E(i, j), this->thresh);
}

}
}

MATX_EXIT_HANDLER();
}

TYPED_TEST(MatmulSparseTestsAll, MatmulCSC) {
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

ExecType exec{};

auto A = makeA<TestType>();
auto B = makeB<TestType>();
auto E = makeE<TestType>();
const auto m = A.Size(0);
const auto k = A.Size(1);
const auto n = B.Size(1);

// Convert dense A to sparse S.
auto S = experimental::make_zero_tensor_csc<TestType, index_t, index_t>({m, k});
(S = dense2sparse(A)).run(exec);
ASSERT_EQ(S.Nse(), 7);
ASSERT_EQ(S.posSize(1), k + 1);

// Matmul.
auto O = make_tensor<TestType>({m, n});
(O = matmul(S, B)).run(exec);

// Verify result.
exec.sync();
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
if constexpr (is_complex_v<TestType>) {
ASSERT_NEAR(O(i, j).real(), E(i, j).real(), this->thresh);
ASSERT_NEAR(O(i, j).imag(), E(i,j ).imag(), this->thresh);
}
else {
ASSERT_NEAR(O(i, j), E(i, j), this->thresh);
}

}
}

MATX_EXIT_HANDLER();
}
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ set (test_sources
01_radar/dct.cu
00_sparse/Basic.cu
00_sparse/Convert.cu
00_sparse/Matmul.cu
00_sparse/Solve.cu
main.cu
)
Expand Down