From 56b509fe8386e6e67b519e8cf6e1e6a96b4f81f1 Mon Sep 17 00:00:00 2001 From: "Aart J.C. Bik" Date: Wed, 22 Jan 2025 10:42:51 -0800 Subject: [PATCH] Guard all DIM/LVL recursion against completely empty format Although unlikely to be useful, the "Scalar" form is a valid storage format in the full space of storage formats that can be described by the DSL. This minor change ensures those cases do not result in compile-time errors. --- examples/sparse_tensor.cu | 18 +++++++- include/matx/core/sparse_tensor_format.h | 58 +++++++++++++----------- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/examples/sparse_tensor.cu b/examples/sparse_tensor.cu index 40cca436..32bc58b8 100644 --- a/examples/sparse_tensor.cu +++ b/examples/sparse_tensor.cu @@ -41,6 +41,22 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) cudaStream_t stream = 0; cudaExecutor exec{stream}; + // + // Print some formats that are used for the versatile sparse tensor + // type. Note that common formats like COO and CSR have good library + // support in e.g. cuSPARSE, but MatX provides a much more general + // way to define the sparse tensor storage through a DSL (see doc). + // + experimental::Scalar::print(); // scalars + experimental::SpVec::print(); // sparse vectors + experimental::COO::print(); // various sparse matrix formats + experimental::CSR::print(); + experimental::CSC::print(); + experimental::DCSR::print(); + experimental::BSR<2,2>::print(); // 2x2 blocks + experimental::COO4::print(); // 4-dim tensor in COO + experimental::CSF5::print(); // 5-dim tensor in CSF + // // Creates a COO matrix for the following 4x8 dense matrix with 5 nonzero // elements, using the factory method that uses MatX tensors for the 1-dim @@ -71,7 +87,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv) // // This shows: // - // tensor_impl_2_f32: Tensor{float} Rank: 2, Sizes:[4, 8], Levels:[4, 8] + // tensor_impl_2_f32: SparseTensor{float} Rank: 2, Sizes:[4, 8], Levels:[4, 8] // nse = 5 // format = ( d0, d1 ) -> ( d0 : compressed(non-unique), d1 : singleton ) // crd[0] = ( 0 0 3 3 3 ) diff --git a/include/matx/core/sparse_tensor_format.h b/include/matx/core/sparse_tensor_format.h index f5b733ef..cdba702d 100644 --- a/include/matx/core/sparse_tensor_format.h +++ b/include/matx/core/sparse_tensor_format.h @@ -255,42 +255,48 @@ template class SparseTensorFormat { template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ static void dim2lvl(const CRD *dims, CRD *lvls, bool asSize) { - using ftype = std::tuple_element_t; - if constexpr (ftype::expr::op == LvlOp::Id) { - lvls[L] = dims[ftype::expr::di]; - } else if constexpr (ftype::expr::op == LvlOp::Div) { - lvls[L] = dims[ftype::expr::di] / ftype::expr::cj; - } else if constexpr (ftype::expr::op == LvlOp::Mod) { - lvls[L] = - asSize ? ftype::expr::cj : (dims[ftype::expr::di] % ftype::expr::cj); - } - if constexpr (L + 1 < LVL) { - dim2lvl(dims, lvls, asSize); + if constexpr (L < LVL) { + using ftype = std::tuple_element_t; + if constexpr (ftype::expr::op == LvlOp::Id) { + lvls[L] = dims[ftype::expr::di]; + } else if constexpr (ftype::expr::op == LvlOp::Div) { + lvls[L] = dims[ftype::expr::di] / ftype::expr::cj; + } else if constexpr (ftype::expr::op == LvlOp::Mod) { + lvls[L] = asSize ? ftype::expr::cj + : (dims[ftype::expr::di] % ftype::expr::cj); + } + if constexpr (L + 1 < LVL) { + dim2lvl(dims, lvls, asSize); + } } } template __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ static void lvl2dim(const CRD *lvls, CRD *dims) { - using ftype = std::tuple_element_t; - if constexpr (ftype::expr::op == LvlOp::Id) { - dims[ftype::expr::di] = lvls[L]; - } else if constexpr (ftype::expr::op == LvlOp::Div) { - dims[ftype::expr::di] = lvls[L] * ftype::expr::cj; - } else if constexpr (ftype::expr::op == LvlOp::Mod) { - dims[ftype::expr::di] += lvls[L]; // update (seen second) - } - if constexpr (L + 1 < LVL) { - lvl2dim(lvls, dims); + if constexpr (L < LVL) { + using ftype = std::tuple_element_t; + if constexpr (ftype::expr::op == LvlOp::Id) { + dims[ftype::expr::di] = lvls[L]; + } else if constexpr (ftype::expr::op == LvlOp::Div) { + dims[ftype::expr::di] = lvls[L] * ftype::expr::cj; + } else if constexpr (ftype::expr::op == LvlOp::Mod) { + dims[ftype::expr::di] += lvls[L]; // update (seen second) + } + if constexpr (L + 1 < LVL) { + lvl2dim(lvls, dims); + } } } template static void printLevel() { - using ftype = std::tuple_element_t; - std::cout << " " << ftype::toString(); - if constexpr (L + 1 < LVL) { - std::cout << ","; - printLevel(); + if constexpr (L < LVL) { + using ftype = std::tuple_element_t; + std::cout << " " << ftype::toString(); + if constexpr (L + 1 < LVL) { + std::cout << ","; + printLevel(); + } } }