From a3703a1c6498fee0bb6a4a9b87a68fe63d2bc887 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Thu, 30 Jan 2025 13:09:10 -0800 Subject: [PATCH 1/6] First version of MATX Sparse-Direct-Solve (using dispatch to cuDSS) --- examples/sparse_tensor.cu | 30 +- include/matx/core/type_utils.h | 3 + include/matx/operators/operators.h | 1 + include/matx/operators/solve.h | 162 +++++++++++ include/matx/transforms/solve/solve_cudss.h | 286 ++++++++++++++++++++ 5 files changed, 481 insertions(+), 1 deletion(-) create mode 100644 include/matx/operators/solve.h create mode 100644 include/matx/transforms/solve/solve_cudss.h diff --git a/examples/sparse_tensor.cu b/examples/sparse_tensor.cu index 241ee8e3..a373b70e 100644 --- a/examples/sparse_tensor.cu +++ b/examples/sparse_tensor.cu @@ -133,10 +133,38 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) print(C); // - // Verify by computing the equivelent dense GEMM. + // Verify by computing the equivalent dense GEMM. // (C = matmul(A, B)).run(exec); print(C); + // + // Creates a CSR matrix which is used to solve the following + // system of equations AX=Y, where X is the unknown. + // + // | 1 2 0 0 | | 1 5 | | 5 17 | + // | 0 3 0 0 | x | 2 6 | = | 6 18 | + // | 0 0 4 0 | | 3 7 | | 12 28 | + // | 0 0 0 5 | | 4 8 | | 20 40 | + // + // Note that X and Y are presented by reshaping a 1-dim + // representation of the column-major storage, because the + // underlying library currently only supports column-major. + // + tensor_t coeffs{{nse}}; + tensor_t rowptr{{nse}}; + tensor_t colidx{{nse}}; + coeffs.SetVals({ 1, 2, 3, 4, 5 }); + rowptr.SetVals({ 0, 2, 3, 4, 5 }); + colidx.SetVals({ 0, 1, 1, 2, 3 }); + auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4}); + print(Acsr); + tensor_t Y{{8}}; + tensor_t X{{8}}; + // TODO: how to avoid the row-major/column-major issue? + Y.SetVals({5, 6, 12, 20, 17, 18, 28, 40}); // col-major + (X.View({4, 2}) = solve(Acsr, Y.View({4, 2}))).run(exec); + print(X); // col-major + MATX_EXIT_HANDLER(); } diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h index 3846ad8b..cf164db6 100644 --- a/include/matx/core/type_utils.h +++ b/include/matx/core/type_utils.h @@ -1126,6 +1126,9 @@ template constexpr cudaDataType_t MatXTypeToCudaType() if constexpr (std::is_same_v) { return CUDA_R_8I; } + if constexpr (std::is_same_v) { + return CUDA_R_32I; + } if constexpr (std::is_same_v) { return CUDA_R_32F; } diff --git a/include/matx/operators/operators.h b/include/matx/operators/operators.h index 78c86aef..9d800ecc 100644 --- a/include/matx/operators/operators.h +++ b/include/matx/operators/operators.h @@ -99,6 +99,7 @@ #include "matx/operators/shift.h" #include "matx/operators/sign.h" #include "matx/operators/slice.h" +#include "matx/operators/solve.h" #include "matx/operators/sort.h" #include "matx/operators/sph2cart.h" #include "matx/operators/stack.h" diff --git a/include/matx/operators/solve.h b/include/matx/operators/solve.h new file mode 100644 index 00000000..4116c0a9 --- /dev/null +++ b/include/matx/operators/solve.h @@ -0,0 +1,162 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2025, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. 
Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/type_utils.h" +#include "matx/operators/base_operator.h" +#ifdef MATX_EN_CUDSS +#include "matx/transforms/solve/solve_cudss.h" +#endif + +namespace matx { +namespace detail { + +template +class SolveOp : public BaseOp> { +private: + typename detail::base_type_t a_; + typename detail::base_type_t b_; + + cuda::std::array out_dims_; + mutable detail::tensor_impl_t tmp_out_; + mutable typename OpA::value_type *ptr = nullptr; + +public: + using matxop = bool; + using matx_transform_op = bool; + using solve_xform_op = bool; + using value_type = typename OpA::value_type; + + __MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) { + for (int r = 0; r < Rank(); r++) { + out_dims_[r] = b_.Size(r); + } + } + + __MATX_INLINE__ std::string str() const { + return "solve(" + get_type_str(a_) + "," + get_type_str(b_) + ")"; + } + + __MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; } + + template + __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ decltype(auto) + operator()(Is... 
indices) const { + return tmp_out_(indices...); + } + + static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t + Rank() { + return remove_cvref_t::Rank(); + } + + constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t + Size(int dim) const { + return out_dims_[dim]; + } + + template + void Exec([[maybe_unused]] Out &&out, [[maybe_unused]] Executor &&ex) const { + static_assert(!is_sparse_tensor_v, "sparse rhs not implemented"); + if constexpr (is_sparse_tensor_v) { +#ifdef MATX_EN_CUDSS + sparse_solve_impl(cuda::std::get<0>(out), a_, b_, ex); +#else + MATX_THROW(matxNotSupported, "Sparse direct solver requires cuDSS"); +#endif + } else { + MATX_THROW(matxNotSupported, + "Direct solver currently only supports sparse system"); + } + } + + template + __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, + Executor &&ex) const noexcept { + if constexpr (is_matx_op()) { + a_.PreRun(std::forward(shape), std::forward(ex)); + } + if constexpr (is_matx_op()) { + b_.PreRun(std::forward(shape), std::forward(ex)); + } + } + + template + __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, + Executor &&ex) const noexcept { + InnerPreRun(std::forward(shape), std::forward(ex)); + detail::AllocateTempTensor(tmp_out_, std::forward(ex), out_dims_, + &ptr); + Exec(cuda::std::make_tuple(tmp_out_), std::forward(ex)); + } + + template + __MATX_INLINE__ void PostRun(ShapeType &&shape, + Executor &&ex) const noexcept { + if constexpr (is_matx_op()) { + a_.PostRun(std::forward(shape), std::forward(ex)); + } + if constexpr (is_matx_op()) { + b_.PostRun(std::forward(shape), std::forward(ex)); + } + matxFree(ptr); + } +}; + +} // end namespace detail + +/** + * Run a direct SOLVE (viz. X = solve(A, B) solves system AX=B for unknown X). + * + * Note that currently, this operation is only implemented for solving + * a linear system with a very **sparse** matrix A. + * + * @tparam OpA + * Data type of A tensor (sparse) + * @tparam OpB + * Data type of B tensor + * + * @param A + * A Sparse tensor with system coefficients + * @param B + * B Dense tensor of known values + * + * @return + * Operator that produces the output tensor X with the solution + */ +template +__MATX_INLINE__ auto solve(const OpA &A, const OpB &B) { + return detail::SolveOp(A, B); +} + +} // end namespace matx diff --git a/include/matx/transforms/solve/solve_cudss.h b/include/matx/transforms/solve/solve_cudss.h new file mode 100644 index 00000000..047e7563 --- /dev/null +++ b/include/matx/transforms/solve/solve_cudss.h @@ -0,0 +1,286 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2025, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include + +#include + +#include "matx/core/cache.h" +#include "matx/core/sparse_tensor.h" +#include "matx/core/tensor.h" + +namespace matx { + +namespace detail { + +/** + * Parameters needed to execute a cuDSS direct SOLVE. + */ +struct SolveCUDSSParams_t { + MatXDataType_t dtype; + MatXDataType_t ptype; + MatXDataType_t ctype; + int rank; + cudaStream_t stream; + index_t nse; + index_t m; + index_t n; + index_t k; + // Matrix handles in cuDSS are data specific (unlike e.g. cuBLAS + // where the same plan can be shared between different data buffers). + void *ptrA0; + void *ptrA1; + void *ptrA2; + void *ptrA3; + void *ptrA4; + void *ptrB; + void *ptrC; +}; + +template +class SolveCUDSSHandle_t { +public: + using TA = typename TensorTypeA::value_type; + using TB = typename TensorTypeB::value_type; + using TC = typename TensorTypeC::value_type; + + static constexpr int RANKA = TensorTypeC::Rank(); + static constexpr int RANKB = TensorTypeC::Rank(); + static constexpr int RANKC = TensorTypeC::Rank(); + + SolveCUDSSHandle_t(TensorTypeC &c, const TensorTypeA &a, const TensorTypeB &b, + cudaStream_t stream) { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL) + + static_assert(RANKA == 2); + static_assert(RANKB == 2); + static_assert(RANKC == 2); + + MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2), matxInvalidSize); + MATX_ASSERT(c.Size(RANKC - 1) == b.Size(RANKB - 1), matxInvalidSize); + MATX_ASSERT(c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize); + + params_ = GetSolveParams(c, a, b, stream); + + [[maybe_unused]] cudssStatus_t ret = cudssCreate(&handle_); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + + // Create cuDSS handle for sparse matrix A. + static_assert(is_sparse_tensor_v); + MATX_ASSERT(TypeToInt == + TypeToInt, + matxNotSupported); + cudaDataType itp = MatXTypeToCudaType(); + cudaDataType dta = MatXTypeToCudaType(); + cudssMatrixType_t mtp = CUDSS_MTYPE_GENERAL; + cudssMatrixViewType_t mvw = CUDSS_MVIEW_FULL; + cudssIndexBase_t bas = CUDSS_BASE_ZERO; + if constexpr (TensorTypeA::Format::isCSR()) { + ret = cudssMatrixCreateCsr(&matA_, params_.m, params_.k, params_.nse, + /*rowStart=*/params_.ptrA2, + /*rowEnd=*/nullptr, params_.ptrA4, + params_.ptrA0, itp, dta, mtp, mvw, bas); + } else { + MATX_THROW(matxNotSupported, "cuDSS currently only supports CSR"); + } + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + + // Create cuDSS handle for dense matrices B and C. 
+ static_assert(is_tensor_view_v); + static_assert(is_tensor_view_v); + cudaDataType dtb = MatXTypeToCudaType(); + cudaDataType dtc = MatXTypeToCudaType(); + cudssLayout_t layout = CUDSS_LAYOUT_COL_MAJOR; // TODO: no ROW + ret = cudssMatrixCreateDn(&matB_, params_.k, params_.n, /*ld=*/params_.k, + params_.ptrB, dtb, layout); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + ret = cudssMatrixCreateDn(&matC_, params_.m, params_.n, /*ld=*/params_.m, + params_.ptrC, dtc, layout); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + + // Allocate configuration and data. + ret = cudssConfigCreate(&config_); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + ret = cudssDataCreate(handle_, &data_); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + + // Set configuration. + cudssAlgType_t reorder_alg = CUDSS_ALG_DEFAULT; + cudssConfigParam_t par = CUDSS_CONFIG_REORDERING_ALG; + ret = cudssConfigSet(config_, par, &reorder_alg, sizeof(cudssAlgType_t)); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + } + + ~SolveCUDSSHandle_t() { + cudssConfigDestroy(config_); + cudssDataDestroy(handle_, data_); + cudssDestroy(handle_); + } + + static detail::SolveCUDSSParams_t GetSolveParams(TensorTypeC &c, + const TensorTypeA &a, + const TensorTypeB &b, + cudaStream_t stream) { + detail::SolveCUDSSParams_t params; + params.dtype = TypeToInt(); + params.ptype = TypeToInt(); + params.ctype = TypeToInt(); + params.rank = c.Rank(); + params.stream = stream; + // TODO: simple no-batch, row-wise, no-transpose for now + params.nse = a.Nse(); + params.m = a.Size(TensorTypeA::Rank() - 2); + params.n = b.Size(TensorTypeB::Rank() - 1); + params.k = a.Size(TensorTypeB::Rank() - 1); + // Matrix handles in cuDSS are data specific. Therefore, the pointers + // to the underlying buffers are part of the GEMM parameters. + params.ptrA0 = a.Data(); + params.ptrA1 = a.POSData(0); + params.ptrA2 = a.POSData(1); + params.ptrA3 = a.CRDData(0); + params.ptrA4 = a.CRDData(1); + params.ptrB = b.Data(); + params.ptrC = c.Data(); + return params; + } + + __MATX_INLINE__ void Exec([[maybe_unused]] TensorTypeC &c, + [[maybe_unused]] const TensorTypeA &a, + [[maybe_unused]] const TensorTypeB &b) { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL); + // TODO: provide a way to expose these three different steps + // (analysis/factorization/solve) individually to user? + [[maybe_unused]] cudssStatus_t ret = cudssExecute( + handle_, CUDSS_PHASE_ANALYSIS, config_, data_, matA_, matC_, matB_); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + ret = cudssExecute(handle_, CUDSS_PHASE_FACTORIZATION, config_, data_, + matA_, matC_, matB_); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + ret = cudssExecute(handle_, CUDSS_PHASE_SOLVE, config_, data_, matA_, matC_, + matB_); + MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); + } + +private: + cudssHandle_t handle_ = nullptr; // TODO: share handle globally? + cudssConfig_t config_ = nullptr; + cudssData_t data_ = nullptr; + cudssMatrix_t matA_ = nullptr; + cudssMatrix_t matB_ = nullptr; + cudssMatrix_t matC_ = nullptr; + detail::SolveCUDSSParams_t params_; +}; + +/** + * Crude hash on SOLVE to get a reasonably good delta for collisions. This + * doesn't need to be perfect, but fast enough to not slow down lookups, and + * different enough so the common SOLVE parameters change. 
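 * (The hash below folds in the A, B, and C data pointers and the stream: cuDSS matrix handles are bound to specific data buffers, so solves on different buffers must map to different cache entries.)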
+ */ +struct SolveCUDSSParamsKeyHash { + std::size_t operator()(const SolveCUDSSParams_t &k) const noexcept { + return std::hash()(reinterpret_cast(k.ptrA0)) + + std::hash()(reinterpret_cast(k.ptrB)) + + std::hash()(reinterpret_cast(k.ptrC)) + + std::hash()(reinterpret_cast(k.stream)); + } +}; + +/** + * Test SOLVE parameters for equality. Unlike the hash, all parameters must + * match exactly to ensure the hashed kernel can be reused for the computation. + */ +struct SolveCUDSSParamsKeyEq { + bool operator()(const SolveCUDSSParams_t &l, + const SolveCUDSSParams_t &t) const noexcept { + return l.dtype == t.dtype && l.ptype == t.ptype && l.ctype == t.ctype && + l.rank == t.rank && l.stream == t.stream && l.nse == t.nse && + l.m == t.m && l.n == t.n && l.k == t.k && l.ptrA0 == t.ptrA0 && + l.ptrA1 == t.ptrA1 && l.ptrA2 == t.ptrA2 && l.ptrA3 == t.ptrA3 && + l.ptrA4 == t.ptrA4 && l.ptrB == t.ptrB && l.ptrC == t.ptrC; + } +}; + +using gemm_cudss_cache_t = + std::unordered_map; + +} // end namespace detail + +template +__MATX_INLINE__ auto getCUDSSSupportedTensor(const Op &in, + cudaStream_t stream) { + const auto support_func = [&in]() { + if constexpr (is_tensor_view_v) { + // TODO: we check for row-major even though we assume + // data is actually stored column-major; this is + // waiting for row-major support in the cuDSS lib + return in.Stride(Op::Rank() - 1) == 1; + } else { + return true; + } + }; + return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream); +} + +template +void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, + const cudaExecutor &exec) { + MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) + const auto stream = exec.getStream(); + + auto a = A; // always sparse + auto b = getCUDSSSupportedTensor(B, stream); + auto c = getCUDSSSupportedTensor(C, stream); + + // TODO: some more checking, supported type? on device? etc. + + typedef decltype(c) ctype; + typedef decltype(a) atype; + typedef decltype(b) btype; + + // Get parameters required by these tensors (for caching). + auto params = detail::SolveCUDSSHandle_t::GetSolveParams( + c, a, b, stream); + + // Lookup and cache. + using cache_val_type = detail::SolveCUDSSHandle_t; + detail::GetCache().LookupAndExec( + detail::GetCacheIdFromType(), params, + [&]() { return std::make_shared(c, a, b, stream); }, + [&](std::shared_ptr cache_type) { + cache_type->Exec(c, a, b); + }); +} + +} // end namespace matx From 6a2d9015d543678c3fe9a71d0c553758f754f8f9 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Thu, 30 Jan 2025 14:56:40 -0800 Subject: [PATCH 2/6] changed to make_tensor factory methods --- examples/sparse_tensor.cu | 77 ++++++++++++++------------------------- 1 file changed, 28 insertions(+), 49 deletions(-) diff --git a/examples/sparse_tensor.cu b/examples/sparse_tensor.cu index a373b70e..eb2ff07d 100644 --- a/examples/sparse_tensor.cu +++ b/examples/sparse_tensor.cu @@ -32,6 +32,8 @@ #include "matx.h" +// Note that sparse tensor support in MatX is still experimental. + using namespace matx; int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) @@ -42,7 +44,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) cudaExecutor exec{stream}; // - // Print some formats that are used for the versatile sparse tensor + // Print some formats that are used for the universal sparse tensor // type. Note that common formats like COO and CSR have good library // support in e.g. 
cuSPARSE, but MatX provides a much more general // way to define the sparse tensor storage through a DSL (see doc). @@ -68,25 +70,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // | 0, 0, 0, 0, 0, 0, 0, 0 | // | 0, 0, 3, 4, 0, 5, 0, 0 | // - - constexpr index_t m = 4; - constexpr index_t n = 8; - constexpr index_t nse = 5; - - tensor_t values{{nse}}; - tensor_t row_idx{{nse}}; - tensor_t col_idx{{nse}}; - - values.SetVals({ 1, 2, 3, 4, 5 }); - row_idx.SetVals({ 0, 0, 3, 3, 3 }); - col_idx.SetVals({ 0, 1, 2, 3, 5 }); - - // Note that sparse tensor support in MatX is still experimental. - auto Acoo = experimental::make_tensor_coo(values, row_idx, col_idx, {m, n}); - - // - // This shows: - // // tensor_impl_2_f32: SparseTensor{float} Rank: 2, Sizes:[4, 8], Levels:[4, 8] // nse = 5 // format = ( d0, d1 ) -> ( d0 : compressed(non-unique), d1 : singleton ) @@ -95,6 +78,13 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // values = ( 1.0000e+00 2.0000e+00 3.0000e+00 4.0000e+00 5.0000e+00 ) // space = CUDA managed memory // + auto vals = make_tensor({5}); + auto idxi = make_tensor({5}); + auto idxj = make_tensor({5}); + vals.SetVals({1, 2, 3, 4, 5}); + idxi.SetVals({0, 0, 3, 3, 3}); + idxj.SetVals({0, 1, 2, 3, 5}); + auto Acoo = experimental::make_tensor_coo(vals, idxi, idxj, {4, 8}); print(Acoo); // @@ -107,9 +97,9 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // use sparse operations that are tailored for the sparse data // structure (such as scanning by row for CSR). // - tensor_t A{{m, n}}; - for (index_t i = 0; i < m; i++) { - for (index_t j = 0; j < n; j++) { + auto A = make_tensor({4, 8}); + for (index_t i = 0; i < 4; i++) { + for (index_t j = 0; j < 8; j++) { A(i, j) = Acoo(i, j); } } @@ -119,25 +109,14 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // SpMM is implemented on COO through cuSPARSE. This is the // correct way of performing an efficient sparse operation. // - tensor_t B{{8, 4}}; - tensor_t C{{4, 4}}; - B.SetVals({ { 0, 1, 2, 3 }, - { 4, 5, 6, 7 }, - { 8, 9, 10, 11 }, - { 12, 13, 14, 15 }, - { 16, 17, 18, 19 }, - { 20, 21, 22, 23 }, - { 24, 25, 26, 27 }, - { 28, 29, 30, 31 } }); + auto B = make_tensor({8, 4}); + auto C = make_tensor({4, 4}); + B.SetVals({ + { 0, 1, 2, 3}, { 4, 5, 6, 7}, { 8, 9, 10, 11}, {12, 13, 14, 15}, + {16, 17, 18, 19}, {20, 21, 22, 23}, {24, 25, 26, 27}, {28, 29, 30, 31} }); (C = matmul(Acoo, B)).run(exec); print(C); - // - // Verify by computing the equivalent dense GEMM. - // - (C = matmul(A, B)).run(exec); - print(C); - // // Creates a CSR matrix which is used to solve the following // system of equations AX=Y, where X is the unknown. @@ -149,19 +128,19 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // // Note that X and Y are presented by reshaping a 1-dim // representation of the column-major storage, because the - // underlying library currently only supports column-major. - // - tensor_t coeffs{{nse}}; - tensor_t rowptr{{nse}}; - tensor_t colidx{{nse}}; - coeffs.SetVals({ 1, 2, 3, 4, 5 }); - rowptr.SetVals({ 0, 2, 3, 4, 5 }); - colidx.SetVals({ 0, 1, 1, 2, 3 }); + // backing library cuDSS currently only supports column-major. 
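// For example, flattening the column-major storage of the 4x2 matrix Y
// column by column yields the 1-dim vector {5, 6, 12, 20, 17, 18, 28, 40}
// used below: column 0 ({5, 6, 12, 20}) first, then column 1 ({17, 18, 28, 40}).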
+ // + auto coeffs = make_tensor({5}); + auto rowptr = make_tensor({5}); + auto colidx = make_tensor({5}); + coeffs.SetVals({1, 2, 3, 4, 5}); + rowptr.SetVals({0, 1, 1, 2, 3}); + colidx.SetVals({0, 2, 3, 4, 5}); auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4}); print(Acsr); - tensor_t Y{{8}}; - tensor_t X{{8}}; // TODO: how to avoid the row-major/column-major issue? + auto X = make_tensor({8}); + auto Y = make_tensor({8}); Y.SetVals({5, 6, 12, 20, 17, 18, 28, 40}); // col-major (X.View({4, 2}) = solve(Acsr, Y.View({4, 2}))).run(exec); print(X); // col-major From f653a1ec3539edf64656366d92370390eced51a0 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Fri, 31 Jan 2025 15:01:06 -0800 Subject: [PATCH 3/6] hide col-major issues from user --- examples/sparse_tensor.cu | 19 +++--- include/matx/transforms/solve/solve_cudss.h | 64 ++++++++++++++------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/examples/sparse_tensor.cu b/examples/sparse_tensor.cu index eb2ff07d..b2659991 100644 --- a/examples/sparse_tensor.cu +++ b/examples/sparse_tensor.cu @@ -126,24 +126,19 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // | 0 0 4 0 | | 3 7 | | 12 28 | // | 0 0 0 5 | | 4 8 | | 20 40 | // - // Note that X and Y are presented by reshaping a 1-dim - // representation of the column-major storage, because the - // backing library cuDSS currently only supports column-major. - // auto coeffs = make_tensor({5}); auto rowptr = make_tensor({5}); auto colidx = make_tensor({5}); coeffs.SetVals({1, 2, 3, 4, 5}); - rowptr.SetVals({0, 1, 1, 2, 3}); - colidx.SetVals({0, 2, 3, 4, 5}); + rowptr.SetVals({0, 2, 3, 4, 5}); + colidx.SetVals({0, 1, 1, 2, 3}); auto Acsr = experimental::make_tensor_csr(coeffs, rowptr, colidx, {4, 4}); print(Acsr); - // TODO: how to avoid the row-major/column-major issue? - auto X = make_tensor({8}); - auto Y = make_tensor({8}); - Y.SetVals({5, 6, 12, 20, 17, 18, 28, 40}); // col-major - (X.View({4, 2}) = solve(Acsr, Y.View({4, 2}))).run(exec); - print(X); // col-major + auto X = make_tensor({4, 2}); + auto Y = make_tensor({4, 2}); + Y.SetVals({ {5, 17}, {6, 18}, {12, 28}, {20, 40} }); + (X = solve(Acsr, Y)).run(exec); + print(X); MATX_EXIT_HANDLER(); } diff --git a/include/matx/transforms/solve/solve_cudss.h b/include/matx/transforms/solve/solve_cudss.h index 047e7563..6d303216 100644 --- a/include/matx/transforms/solve/solve_cudss.h +++ b/include/matx/transforms/solve/solve_cudss.h @@ -87,9 +87,10 @@ class SolveCUDSSHandle_t { static_assert(RANKB == 2); static_assert(RANKC == 2); - MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 2), matxInvalidSize); - MATX_ASSERT(c.Size(RANKC - 1) == b.Size(RANKB - 1), matxInvalidSize); - MATX_ASSERT(c.Size(RANKC - 2) == a.Size(RANKA - 2), matxInvalidSize); + // Note: B,C transposed! 
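// (For the 4x4 system in sparse_tensor.cu with two right-hand sides, this
// means m = k = 4 and n = 2, while the b and c buffers passed in here hold
// the 2x4 transposes -- see sparse_solve_impl below.)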
+ MATX_ASSERT(a.Size(RANKA - 1) == b.Size(RANKB - 1), matxInvalidSize); + MATX_ASSERT(a.Size(RANKA - 2) == b.Size(RANKB - 1), matxInvalidSize); + MATX_ASSERT(b.Size(RANKB - 2) == c.Size(RANKC - 2), matxInvalidSize); params_ = GetSolveParams(c, a, b, stream); @@ -121,11 +122,11 @@ class SolveCUDSSHandle_t { static_assert(is_tensor_view_v); cudaDataType dtb = MatXTypeToCudaType(); cudaDataType dtc = MatXTypeToCudaType(); - cudssLayout_t layout = CUDSS_LAYOUT_COL_MAJOR; // TODO: no ROW - ret = cudssMatrixCreateDn(&matB_, params_.k, params_.n, /*ld=*/params_.k, + cudssLayout_t layout = CUDSS_LAYOUT_COL_MAJOR; // no ROW-MAJOR in cuDSS yet + ret = cudssMatrixCreateDn(&matB_, params_.m, params_.n, /*ld=*/params_.m, params_.ptrB, dtb, layout); MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); - ret = cudssMatrixCreateDn(&matC_, params_.m, params_.n, /*ld=*/params_.m, + ret = cudssMatrixCreateDn(&matC_, params_.k, params_.n, /*ld=*/params_.k, params_.ptrC, dtc, layout); MATX_ASSERT(ret == CUDSS_STATUS_SUCCESS, matxSolverError); @@ -161,8 +162,8 @@ class SolveCUDSSHandle_t { // TODO: simple no-batch, row-wise, no-transpose for now params.nse = a.Nse(); params.m = a.Size(TensorTypeA::Rank() - 2); - params.n = b.Size(TensorTypeB::Rank() - 1); - params.k = a.Size(TensorTypeB::Rank() - 1); + params.n = c.Size(TensorTypeC::Rank() - 2); // Note: B,C transposed! + params.k = a.Size(TensorTypeA::Rank() - 1); // Matrix handles in cuDSS are data specific. Therefore, the pointers // to the underlying buffers are part of the GEMM parameters. params.ptrA0 = a.Data(); @@ -240,22 +241,13 @@ using gemm_cudss_cache_t = template __MATX_INLINE__ auto getCUDSSSupportedTensor(const Op &in, cudaStream_t stream) { - const auto support_func = [&in]() { - if constexpr (is_tensor_view_v) { - // TODO: we check for row-major even though we assume - // data is actually stored column-major; this is - // waiting for row-major support in the cuDSS lib - return in.Stride(Op::Rank() - 1) == 1; - } else { - return true; - } - }; + const auto support_func = [&in]() { return true; }; return GetSupportedTensor(in, support_func, MATX_ASYNC_DEVICE_MEMORY, stream); } template -void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, - const cudaExecutor &exec) { +void sparse_solve_impl_trans(TensorTypeC C, const TensorTypeA A, + const TensorTypeB B, const cudaExecutor &exec) { MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) const auto stream = exec.getStream(); @@ -283,4 +275,36 @@ void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, }); } +// Since cuDSS currently only supports column-major storage of the dense +// matrices (and CSR for the sparse matrix), the current implementation +// tranposes B and C prior to entering a tranposed version for SOLVE. This +// convoluted way of performing the solve step must be removed once cuDSS +// supports MATX native row-major storage, which will clean up the copies from +// and to memory. +template +void sparse_solve_impl(TensorTypeC C, const TensorTypeA A, const TensorTypeB B, + const cudaExecutor &exec) { + const auto stream = exec.getStream(); + + // Some copying-in hacks, assumes rank 2. 
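// (Rationale, assuming the temporaries below use MatX's default row-major
// layout: a row-major transpose occupies memory exactly like the column-major
// original, so copying B into the explicitly transposed bT, and reading the
// solution back out of cT, hands cuDSS the column-major data it expects
// without changing the math.)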
+ using TB = typename TensorTypeB::value_type; + using TC = typename TensorTypeB::value_type; + TB *bptr; + matxAlloc(reinterpret_cast(&bptr), + sizeof(TB) * B.Size(0) * B.Size(1), MATX_ASYNC_DEVICE_MEMORY, + stream); + auto bT = make_tensor(bptr, {B.Size(1), B.Size(0)}); + (bT = transpose(B)).run(exec); + TC *cptr; + matxAlloc(reinterpret_cast(&cptr), + sizeof(TC) * C.Size(0) * C.Size(1), MATX_ASYNC_DEVICE_MEMORY, + stream); + auto cT = make_tensor(cptr, {C.Size(1), C.Size(0)}); + + sparse_solve_impl_trans(cT, A, bT, exec); + + // Some copying-back hacks. + (C = transpose(cT)).run(exec); +} + } // end namespace matx From 2a1356bb80d99c3c19f996b0be48e9d47f716b0e Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Mon, 3 Feb 2025 09:35:37 -0800 Subject: [PATCH 4/6] proper outrank expressions --- include/matx/operators/solve.h | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/include/matx/operators/solve.h b/include/matx/operators/solve.h index 4116c0a9..bec5c3e9 100644 --- a/include/matx/operators/solve.h +++ b/include/matx/operators/solve.h @@ -47,8 +47,9 @@ class SolveOp : public BaseOp> { typename detail::base_type_t a_; typename detail::base_type_t b_; - cuda::std::array out_dims_; - mutable detail::tensor_impl_t tmp_out_; + static constexpr int out_rank = OpB::Rank(); + cuda::std::array out_dims_; + mutable detail::tensor_impl_t tmp_out_; mutable typename OpA::value_type *ptr = nullptr; public: @@ -58,7 +59,7 @@ class SolveOp : public BaseOp> { using value_type = typename OpA::value_type; __MATX_INLINE__ SolveOp(const OpA &a, const OpB &b) : a_(a), b_(b) { - for (int r = 0; r < Rank(); r++) { + for (int r = 0, rank = Rank(); r < rank; r++) { out_dims_[r] = b_.Size(r); } } From 76dcdd2d9ec4cec4fa937842770bcfa9f725d481 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Mon, 3 Feb 2025 14:09:58 -0800 Subject: [PATCH 5/6] remove redundant pre/post proc --- include/matx/operators/solve.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/include/matx/operators/solve.h b/include/matx/operators/solve.h index bec5c3e9..31838374 100644 --- a/include/matx/operators/solve.h +++ b/include/matx/operators/solve.h @@ -104,9 +104,8 @@ class SolveOp : public BaseOp> { template __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, Executor &&ex) const noexcept { - if constexpr (is_matx_op()) { - a_.PreRun(std::forward(shape), std::forward(ex)); - } + static_assert(is_sparse_tensor_v, + "Direct solver currently only supports sparse system"); if constexpr (is_matx_op()) { b_.PreRun(std::forward(shape), std::forward(ex)); } @@ -124,9 +123,8 @@ class SolveOp : public BaseOp> { template __MATX_INLINE__ void PostRun(ShapeType &&shape, Executor &&ex) const noexcept { - if constexpr (is_matx_op()) { - a_.PostRun(std::forward(shape), std::forward(ex)); - } + static_assert(is_sparse_tensor_v, + "Direct solver currently only supports sparse system"); if constexpr (is_matx_op()) { b_.PostRun(std::forward(shape), std::forward(ex)); } From 4db6b288040352082f7b4b82f7151255a3c47c0f Mon Sep 17 00:00:00 2001 From: "Aart J.C. 
Bik" Date: Mon, 3 Feb 2025 14:26:43 -0800 Subject: [PATCH 6/6] guard against unused --- include/matx/operators/solve.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/matx/operators/solve.h b/include/matx/operators/solve.h index 31838374..07423809 100644 --- a/include/matx/operators/solve.h +++ b/include/matx/operators/solve.h @@ -103,7 +103,7 @@ class SolveOp : public BaseOp> { template __MATX_INLINE__ void InnerPreRun([[maybe_unused]] ShapeType &&shape, - Executor &&ex) const noexcept { + [[maybe_unused]] Executor &&ex) const noexcept { static_assert(is_sparse_tensor_v, "Direct solver currently only supports sparse system"); if constexpr (is_matx_op()) { @@ -113,7 +113,7 @@ class SolveOp : public BaseOp> { template __MATX_INLINE__ void PreRun([[maybe_unused]] ShapeType &&shape, - Executor &&ex) const noexcept { + [[maybe_unused]] Executor &&ex) const noexcept { InnerPreRun(std::forward(shape), std::forward(ex)); detail::AllocateTempTensor(tmp_out_, std::forward(ex), out_dims_, &ptr); @@ -121,8 +121,8 @@ class SolveOp : public BaseOp> { } template - __MATX_INLINE__ void PostRun(ShapeType &&shape, - Executor &&ex) const noexcept { + __MATX_INLINE__ void PostRun([[maybe_unused]] ShapeType &&shape, + [[maybe_unused]]Executor &&ex) const noexcept { static_assert(is_sparse_tensor_v, "Direct solver currently only supports sparse system"); if constexpr (is_matx_op()) {