Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First version of MATX Sparse-Direct-Solve (using dispatch to cuDSS) #849

Merged
merged 6 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 43 additions & 36 deletions examples/sparse_tensor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@

#include "matx.h"

// Note that sparse tensor support in MatX is still experimental.

using namespace matx;

int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
Expand All @@ -42,7 +44,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
cudaExecutor exec{stream};

//
// Print some formats that are used for the versatile sparse tensor
// Print some formats that are used for the universal sparse tensor
// type. Note that common formats like COO and CSR have good library
// support in e.g. cuSPARSE, but MatX provides a much more general
// way to define the sparse tensor storage through a DSL (see doc).
Expand All @@ -68,25 +70,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// | 0, 0, 0, 0, 0, 0, 0, 0 |
// | 0, 0, 3, 4, 0, 5, 0, 0 |
//

constexpr index_t m = 4;
constexpr index_t n = 8;
constexpr index_t nse = 5;

tensor_t<float, 1> values{{nse}};
tensor_t<int, 1> row_idx{{nse}};
tensor_t<int, 1> col_idx{{nse}};

values.SetVals({ 1, 2, 3, 4, 5 });
row_idx.SetVals({ 0, 0, 3, 3, 3 });
col_idx.SetVals({ 0, 1, 2, 3, 5 });

// Note that sparse tensor support in MatX is still experimental.
auto Acoo = experimental::make_tensor_coo(values, row_idx, col_idx, {m, n});

//
// This shows:
//
// tensor_impl_2_f32: SparseTensor{float} Rank: 2, Sizes:[4, 8], Levels:[4, 8]
// nse = 5
// format = ( d0, d1 ) -> ( d0 : compressed(non-unique), d1 : singleton )
Expand All @@ -95,6 +78,13 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// values = ( 1.0000e+00 2.0000e+00 3.0000e+00 4.0000e+00 5.0000e+00 )
// space = CUDA managed memory
//
auto vals = make_tensor<float>({5});
auto idxi = make_tensor<int>({5});
auto idxj = make_tensor<int>({5});
vals.SetVals({1, 2, 3, 4, 5});
idxi.SetVals({0, 0, 3, 3, 3});
idxj.SetVals({0, 1, 2, 3, 5});
auto Acoo = experimental::make_tensor_coo(vals, idxi, idxj, {4, 8});
print(Acoo);

//
Expand All @@ -107,9 +97,9 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// use sparse operations that are tailored for the sparse data
// structure (such as scanning by row for CSR).
//
tensor_t<float, 2> A{{m, n}};
for (index_t i = 0; i < m; i++) {
for (index_t j = 0; j < n; j++) {
auto A = make_tensor<float>({4, 8});
for (index_t i = 0; i < 4; i++) {
for (index_t j = 0; j < 8; j++) {
A(i, j) = Acoo(i, j);
}
}
Expand All @@ -119,24 +109,41 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
// SpMM is implemented on COO through cuSPARSE. This is the
// correct way of performing an efficient sparse operation.
//
tensor_t<float, 2> B{{8, 4}};
tensor_t<float, 2> C{{4, 4}};
B.SetVals({ { 0, 1, 2, 3 },
{ 4, 5, 6, 7 },
{ 8, 9, 10, 11 },
{ 12, 13, 14, 15 },
{ 16, 17, 18, 19 },
{ 20, 21, 22, 23 },
{ 24, 25, 26, 27 },
{ 28, 29, 30, 31 } });
auto B = make_tensor<float, 2>({8, 4});
auto C = make_tensor<float>({4, 4});
B.SetVals({
{ 0, 1, 2, 3}, { 4, 5, 6, 7}, { 8, 9, 10, 11}, {12, 13, 14, 15},
{16, 17, 18, 19}, {20, 21, 22, 23}, {24, 25, 26, 27}, {28, 29, 30, 31} });
(C = matmul(Acoo, B)).run(exec);
print(C);

//
// Verify by computing the equivelent dense GEMM.
// Creates a CSR matrix which is used to solve the following
// system of equations AX=Y, where X is the unknown.
//
(C = matmul(A, B)).run(exec);
print(C);
// | 1 2 0 0 | | 1 5 | | 5 17 |
// | 0 3 0 0 | x | 2 6 | = | 6 18 |
// | 0 0 4 0 | | 3 7 | | 12 28 |
// | 0 0 0 5 | | 4 8 | | 20 40 |
//
// Note that X and Y are presented by reshaping a 1-dim
// representation of the column-major storage, because the
// backing library cuDSS currently only supports column-major.
//
auto coeffs = make_tensor<float>({5});
auto rowptr = make_tensor<int>({5});
auto colidx = make_tensor<int>({5});
coeffs.SetVals({1, 2, 3, 4, 5});
rowptr.SetVals({0, 1, 1, 2, 3});
colidx.SetVals({0, 2, 3, 4, 5});
auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4});
print(Acsr);
// TODO: how to avoid the row-major/column-major issue?
auto X = make_tensor<float>({8});
auto Y = make_tensor<float>({8});
Y.SetVals({5, 6, 12, 20, 17, 18, 28, 40}); // col-major
aartbik marked this conversation as resolved.
Show resolved Hide resolved
(X.View({4, 2}) = solve(Acsr, Y.View({4, 2}))).run(exec);
print(X); // col-major

MATX_EXIT_HANDLER();
}
3 changes: 3 additions & 0 deletions include/matx/core/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,9 @@ template <typename T> constexpr cudaDataType_t MatXTypeToCudaType()
if constexpr (std::is_same_v<T, int8_t>) {
return CUDA_R_8I;
}
if constexpr (std::is_same_v<T, int>) {
return CUDA_R_32I;
}
if constexpr (std::is_same_v<T, float>) {
return CUDA_R_32F;
}
Expand Down
1 change: 1 addition & 0 deletions include/matx/operators/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
#include "matx/operators/shift.h"
#include "matx/operators/sign.h"
#include "matx/operators/slice.h"
#include "matx/operators/solve.h"
#include "matx/operators/sort.h"
#include "matx/operators/sph2cart.h"
#include "matx/operators/stack.h"
Expand Down
162 changes: 162 additions & 0 deletions include/matx/operators/solve.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
////////////////////////////////////////////////////////////////////////////////
// BSD 3-Clause License
//
// Copyright (c) 2025, NVIDIA Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice, this
// list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include "matx/core/type_utils.h"
#include "matx/operators/base_operator.h"
#ifdef MATX_EN_CUDSS
#include "matx/transforms/solve/solve_cudss.h"
#endif

namespace matx {
namespace detail {

template <typename OpA, typename OpB>
class SolveOp : public BaseOp<SolveOp<OpA, OpB>> {
private:
typename detail::base_type_t<OpA> a_;
typename detail::base_type_t<OpB> b_;

cuda::std::array<index_t, 2> out_dims_;
mutable detail::tensor_impl_t<typename OpA::value_type, 2> tmp_out_;
cliffburdick marked this conversation as resolved.
Show resolved Hide resolved
mutable typename OpA::value_type *ptr = nullptr;

public:
using matxop = bool;
using matx_transform_op = bool;
using solve_xform_op = bool;
using value_type = typename OpA::value_type;

__MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) {
for (int r = 0; r < Rank(); r++) {
out_dims_[r] = b_.Size(r);
}
}

__MATX_INLINE__ std::string str() const {
return "solve(" + get_type_str(a_) + "," + get_type_str(b_) + ")";
}

__MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; }

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto)
operator()(Is... indices) const {
return tmp_out_(indices...);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t
Rank() {
return remove_cvref_t<OpB>::Rank();
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t
Size(int dim) const {
return out_dims_[dim];
}

template <typename Out, typename Executor>
void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const {
static_assert(!is_sparse_tensor_v<OpB>, "sparse rhs not implemented");
if constexpr (is_sparse_tensor_v<OpA>) {
#ifdef MATX_EN_CUDSS
sparse_solve_impl(cuda::std::get<0>(out), a_, b_, ex);
#else
MATX_THROW(matxNotSupported, "Sparse direct solver requires cuDSS");
#endif
} else {
MATX_THROW(matxNotSupported,
"Direct solver currently only supports sparse system");
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape,
Executor &&ex) const noexcept {
if constexpr (is_matx_op<OpA>()) {
a_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
if constexpr (is_matx_op<OpB>()) {
b_.PreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape,
Executor &&ex) const noexcept {
InnerPreRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
detail::AllocateTempTensor(tmp_out_, std::forward<Executor>(ex), out_dims_,
&ptr);
Exec(cuda::std::make_tuple(tmp_out_), std::forward<Executor>(ex));
}

template <typename ShapeType, typename Executor>
__MATX_INLINE__ void PostRun(ShapeType &&shape,
Executor &&ex) const noexcept {
if constexpr (is_matx_op<OpA>()) {
a_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
if constexpr (is_matx_op<OpB>()) {
b_.PostRun(std::forward<ShapeType>(shape), std::forward<Executor>(ex));
}
matxFree(ptr);
}
};

} // end namespace detail

/**
* Run a direct SOLVE (viz. X = solve(A, B) solves system AX=B for unknown X).
*
* Note that currently, this operation is only implemented for solving
* a linear system with a very **sparse** matrix A.
*
* @tparam OpA
* Data type of A tensor (sparse)
* @tparam OpB
* Data type of B tensor
*
* @param A
* A Sparse tensor with system coefficients
* @param B
* B Dense tensor of known values
*
* @return
* Operator that produces the output tensor X with the solution
*/
template <typename OpA, typename OpB>
__MATX_INLINE__ auto solve(const OpA &A, const OpB &B) {
return detail::SolveOp(A, B);
}

} // end namespace matx
Loading