Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extracting Q from QR decomposition in SPQR #85

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/polysolve/linear/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ set(SOURCES
Pardiso.hpp
SaddlePointSolver.cpp
SaddlePointSolver.hpp
SPQR.cpp
SPQR.hpp
)

source_group(TREE "${CMAKE_CURRENT_SOURCE_DIR}" PREFIX "Source Files" FILES ${SOURCES})
Expand Down
20 changes: 20 additions & 0 deletions src/polysolve/linear/SPQR.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "SPQR.hpp"

namespace polysolve::linear
{
template <>
void EigenDirect<Eigen::SPQR<StiffnessMatrix>>::analyze_pattern(const StiffnessMatrix &A, const int precond_num)
{
m_Solver.compute(A);
}
template <>
void EigenDirect<Eigen::SPQR<StiffnessMatrix>>::factorize(const StiffnessMatrix &A)
{
m_Solver.compute(A);
if (m_Solver.info() == Eigen::NumericalIssue)
{
throw std::runtime_error("[EigenDirect] NumericalIssue encountered.");
}
}

} // namespace polysolve::linear
96 changes: 96 additions & 0 deletions src/polysolve/linear/SPQR.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
#pragma once
#ifdef POLYSOLVE_WITH_SPQR
#include <Eigen/Sparse>
#include <Eigen/SPQRSupport>
#include "EigenSolver.hpp"
#include "Solver.hpp"
namespace polysolve::linear
{
template <>
void EigenDirect<Eigen::SPQR<StiffnessMatrix>>::analyze_pattern(const StiffnessMatrix &A, const int precond_num);
template <>
void EigenDirect<Eigen::SPQR<StiffnessMatrix>>::factorize(const StiffnessMatrix &A);

class SPQRSolver : public EigenDirect<Eigen::SPQR<StiffnessMatrix>>
{

StiffnessMatrix matrixQ() const;
};
} // namespace polysolve::linear
namespace Eigen
{
template <typename SPQRType, typename Derived>
struct SPQR_QSparseProduct;
namespace internal
{
template <typename SPQRType, typename Derived>
struct traits<SPQR_QSparseProduct<SPQRType, Derived>>
{
typedef typename Derived::PlainObject ReturnType;
};
} // namespace internal
template <>
struct SPQRMatrixQReturnType<SPQR<polysolve::StiffnessMatrix>>
{

using SPQRType = SPQR<polysolve::StiffnessMatrix>;
SPQRMatrixQReturnType(const SPQRType &spqr) : m_spqr(spqr) {}
template <typename Derived>
SPQR_QProduct<SPQRType, Derived> operator*(const MatrixBase<Derived> &other)
{
return SPQR_QProduct<SPQRType, Derived>(m_spqr, other.derived(), false);
}
template <typename Derived>
SPQR_QSparseProduct<SPQRType, Derived> operator*(const SparseMatrixBase<Derived> &other)
{
return SPQR_QSparseProduct<SPQRType, Derived>(m_spqr, other.derived(), false);
}
SPQRMatrixQTransposeReturnType<SPQRType> adjoint() const
{
return SPQRMatrixQTransposeReturnType<SPQRType>(m_spqr);
}
// To use for operations with the transpose of Q
SPQRMatrixQTransposeReturnType<SPQRType> transpose() const
{
return SPQRMatrixQTransposeReturnType<SPQRType>(m_spqr);
}
const SPQRType &m_spqr;
};

template <typename SPQRType, typename Derived>
struct SPQR_QSparseProduct : ReturnByValue<SPQR_QSparseProduct<SPQRType, Derived>>
{
struct SPQRTypeWrap : public SPQRType
{
using SPQRType::m_H;
using SPQRType::m_HPinv;
using SPQRType::m_HTau;
};
typedef typename SPQRType::Scalar Scalar;
typedef typename SPQRType::StorageIndex StorageIndex;
// Define the constructor to get reference to argument types
SPQR_QSparseProduct(const SPQRType &spqr, const Derived &other, bool transpose) : m_spqr(spqr), m_other(other), m_transpose(transpose) {}

const SPQRTypeWrap &spqr_w() const { return reinterpret_cast<const SPQRTypeWrap &>(m_spqr); }

inline Index rows() const { return m_transpose ? m_spqr.rows() : m_spqr.cols(); }
inline Index cols() const { return m_other.cols(); }
// Assign to a vector
template <typename ResType>
void evalTo(ResType &res) const
{
cholmod_sparse y_cd;
cholmod_sparse *x_cd;
int method = m_transpose ? SPQR_QTX : SPQR_QX;
cholmod_common *cc = m_spqr.cholmodCommon();
y_cd = viewAsCholmod(m_other.const_cast_derived());
x_cd = SuiteSparseQR_qmult<Scalar>(method, spqr_w().m_H, spqr_w().m_HTau, spqr_w().m_HPinv, &y_cd, cc);
res = viewAsEigen<Scalar, ColMajor, StorageIndex>(*x_cd);
cholmod_l_free_sparse(&x_cd, cc);
}
const SPQRType &m_spqr;
const Derived &m_other;
bool m_transpose;
};
} // namespace Eigen
#endif
17 changes: 1 addition & 16 deletions src/polysolve/linear/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,7 @@ template <typename _MatrixType, typename _OrderingType = COLAMDOrdering<typename
#include <Eigen/UmfPackSupport>
#endif
#ifdef POLYSOLVE_WITH_SPQR
#include <Eigen/SPQRSupport>
namespace polysolve::linear {
template <>
void EigenDirect<Eigen::SPQR<StiffnessMatrix>>::analyze_pattern(const StiffnessMatrix& A, const int precond_num) {
m_Solver.compute(A);
}
template <>
void EigenDirect<Eigen::SPQR<StiffnessMatrix>>::factorize(const StiffnessMatrix &A)
{
m_Solver.compute(A);
if (m_Solver.info() == Eigen::NumericalIssue)
{
throw std::runtime_error("[EigenDirect] NumericalIssue encountered.");
}
}
}
#include "SPQR.hpp"
#endif
#ifdef POLYSOLVE_WITH_SUPERLU
#include <Eigen/SuperLUSupport>
Expand Down
4 changes: 4 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ endif()
include(polyfem-data)
target_link_libraries(unit_tests PRIVATE polyfem::data)

if(POLYSOLVE_WITH_SPQR)
target_link_libraries(unit_tests PRIVATE SuiteSparse::SPQR)
endif()

################################################################################
# Register tests
################################################################################
Expand Down
65 changes: 65 additions & 0 deletions tests/test_linear_solver.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//////////////////////////////////////////////////////////////////////////
#include <polysolve/Types.hpp>
#include <polysolve/linear/FEMSolver.hpp>
#include <polysolve/linear/SPQR.hpp>

#include <polysolve/Utils.hpp>

Expand Down Expand Up @@ -856,3 +857,67 @@ TEST_CASE("cusolverdn_5cubes", "[solver]")
REQUIRE(err < 1e-8);
}
}

TEST_CASE("spqr_sparse_product", "[solver]")
{

Eigen::MatrixXd A(4, 4);
for (int i = 0; i < 4; i++)
{
A(i, i) = 1.0;
}
A(0, 1) = 1.0;
A(3, 0) = 1.0;
std::unique_ptr<Solver> solver;
try
{
solver = Solver::create("Eigen::SPQR", "");
}
catch (const std::exception &)
{
return;
}

using Type = EigenDirect<Eigen::SPQR<polysolve::StiffnessMatrix>>;

Type *typed_solver = dynamic_cast<Type *>(solver.get());
REQUIRE(typed_solver != nullptr);
// solver->set_parameters(params);

for (int i = 0; i < 5; ++i)
{
A = Eigen::MatrixXd::Random(5, 5);
auto As = A.sparseView().eval();

// do a qr so i have a q
Eigen::SPQR<StiffnessMatrix> spqr(As);
auto Q = spqr.matrixQ();

// get a random matrix to multiply against
Eigen::MatrixXd B(A.rows(), 5);
B.setRandom();
Eigen::MatrixXd dense = Q * B;

// make a sparse version
auto Bs = B.sparseView().eval();
StiffnessMatrix sparse = Q * Bs;

// check that the result of the product is the same
CHECK((dense - sparse).norm() < 1e-10);

// try to extract the Q matrix as a dense matrix
Eigen::MatrixXd Id = Eigen::MatrixXd::Identity(A.rows(), A.rows());
Eigen::MatrixXd denseQ = Q * Id;

// use the product with B to get a weak equivalence once again
Eigen::MatrixXd dense2 = denseQ * B;
CHECK((dense2 - dense).norm() < 1e-10);

// now try using a sparse product
StiffnessMatrix I(A.rows(), A.rows());
I.setIdentity();
StiffnessMatrix myQ = Q * I;
StiffnessMatrix sparse2 = myQ * Bs;
CHECK((sparse2 - sparse).norm() < 1e-10);
}
}
Loading