Skip to content

Commit

Permalink
Merge pull request #704 from Pressio/gmres
Browse files Browse the repository at this point in the history
gmres for linear solver and matrix-free newton krylov solver
  • Loading branch information
fnrizzi authored Nov 13, 2024
2 parents 2ed7a87 + 43c54b3 commit 8c95f00
Show file tree
Hide file tree
Showing 21 changed files with 923 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
//@HEADER
// ************************************************************************
//
// solvers_linear_eigen_iterative_matrix_free_impl.hpp
// Pressio
// Copyright 2019
// National Technology & Engineering Solutions of Sandia, LLC (NTESS)
//
// Under the terms of Contract DE-NA0003525 with NTESS, the
// U.S. Government retains certain rights in this software.
//
// Pressio is licensed under BSD-3-Clause terms of use:
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Francesco Rizzi (fnrizzi@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#ifndef PRESSIO_SOLVERS_LINEAR_IMPL_SOLVERS_LINEAR_EIGEN_ITERATIVE_MATRIX_FREE_IMPL_HPP_
#define PRESSIO_SOLVERS_LINEAR_IMPL_SOLVERS_LINEAR_EIGEN_ITERATIVE_MATRIX_FREE_IMPL_HPP_

#include "solvers_linear_iterative_base.hpp"
#include <Eigen/Core>
#include <Eigen/Dense>

namespace pressio { namespace linearsolvers{

template<typename UserDefinedLinearOperatorType>
class OperatorWrapper;

}}

namespace Eigen {
namespace internal {
template<typename UserDefinedLinearOperatorType>
struct traits< pressio::linearsolvers::OperatorWrapper<UserDefinedLinearOperatorType> >
: public Eigen::internal::traits<
Eigen::Matrix<typename UserDefinedLinearOperatorType::scalar_type,-1,-1>
>
{};
}
}

namespace pressio { namespace linearsolvers{

template<typename UserDefinedLinearOperatorType>
class OperatorWrapper :
public Eigen::EigenBase<
OperatorWrapper<UserDefinedLinearOperatorType>
>
{
public:
using Scalar = typename UserDefinedLinearOperatorType::scalar_type;
using RealScalar = Scalar;
using StorageIndex = int;
enum {
ColsAtCompileTime = Eigen::Dynamic,
MaxColsAtCompileTime = Eigen::Dynamic,
};

OperatorWrapper() = default;

OperatorWrapper(UserDefinedLinearOperatorType const & valueIn)
: m_userOperator(&valueIn) {}


int rows() const { return m_userOperator->rows(); }
int cols() const { return m_userOperator->cols(); }

template<typename Rhs>
Eigen::Product<OperatorWrapper<UserDefinedLinearOperatorType>, Rhs, Eigen::AliasFreeProduct>
operator*(const Eigen::MatrixBase<Rhs>& x) const{
using r_t = Eigen::Product<
OperatorWrapper<UserDefinedLinearOperatorType>, Rhs, Eigen::AliasFreeProduct
>;
return r_t(*this, x.derived());
}

void replace(const UserDefinedLinearOperatorType & opIn) {
m_userOperator = &opIn;
}

template<class OperandT, class ResultT>
void applyAndAddTo(OperandT const & operand, ResultT & out) const {
// compute: out += operator * operand
m_userOperator->applyAndAddTo(operand, out);
}

private:
UserDefinedLinearOperatorType const *m_userOperator = nullptr;
};

}} // end namespace pressio::linearsolvers

namespace Eigen {
namespace internal {

template<typename Rhs, typename UserDefinedOpT>
struct generic_product_impl<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>,
Rhs, DenseShape, DenseShape, GemvProduct
>
: generic_product_impl_base<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs,
generic_product_impl<pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs>
>
{
using Scalar = typename Product<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>,
Rhs
>::Scalar;

template<typename Dest>
static void scaleAndAddTo(
Dest& dst,
const pressio::linearsolvers::OperatorWrapper<UserDefinedOpT> & lhs,
const Rhs& rhs,
const Scalar& alpha)
{
// This method should implement "dst += alpha * lhs * rhs" inplace,
// however, for iterative solvers, alpha is always equal to 1,
// so let's not bother about it.
assert(alpha==Scalar(1) && "scaling is not implemented");
EIGEN_ONLY_USED_FOR_DEBUG(alpha);

lhs.applyAndAddTo(rhs, dst);
}
};
}
}

namespace pressio { namespace linearsolvers{ namespace impl{

template<typename TagType, typename UserDefinedLinearOperatorType>
class EigenIterativeMatrixFree
: public IterativeBase<
EigenIterativeMatrixFree<TagType, UserDefinedLinearOperatorType>
>
{

public:
using this_type = EigenIterative<TagType, UserDefinedLinearOperatorType>;
using scalar_type = typename UserDefinedLinearOperatorType::scalar_type;
using solver_traits = ::pressio::linearsolvers::Traits<TagType>;
using op_wrapper_t = OperatorWrapper<UserDefinedLinearOperatorType>;
using native_solver_type = typename solver_traits::template eigen_solver_type<op_wrapper_t>;
using base_iterative_type = IterativeBase<this_type>;
using iteration_type = typename base_iterative_type::iteration_type;

static_assert(solver_traits::eigen_enabled == true,
"the native solver must be from Eigen to use in EigenIterativeMatrixFree");
static_assert(solver_traits::direct == false,
"The native eigen solver must be iterative to use in EigenIterativeMatrixFree");

public:
EigenIterativeMatrixFree() = default;

iteration_type numIterationsExecuted() const{
return mysolver_.iterations();
}

scalar_type finalError() const{
return mysolver_.error();
}

void resetLinearSystem(const UserDefinedLinearOperatorType& A)
{
mysolver_.setMaxIterations(this->maxIters_);
m_wrapper.replace(A);
mysolver_.compute(m_wrapper);
}

template <typename T>
void solve(const T& b, T & y){
mysolver_.setMaxIterations(this->maxIters_);
y = mysolver_.solve(b);
}

template <typename T>
void solve(const UserDefinedLinearOperatorType & A, const T& b, T & y){
this->resetLinearSystem(A);
this->solve(b, y);
}

private:
friend base_iterative_type;
native_solver_type mysolver_ = {};
op_wrapper_t m_wrapper;
};

}}} // end namespace pressio::solvers::iterarive::impl
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#ifdef PRESSIO_ENABLE_TPL_EIGEN
#include "solvers_linear_eigen_direct_impl.hpp"
#include "solvers_linear_eigen_iterative_impl.hpp"
#include "solvers_linear_eigen_iterative_matrix_free_impl.hpp"
#endif
#ifdef PRESSIO_ENABLE_TPL_KOKKOS
#include "solvers_linear_kokkos_direct_geqrf_impl.hpp"
Expand All @@ -69,6 +70,16 @@ struct Selector{
};

#ifdef PRESSIO_ENABLE_TPL_EIGEN
template<typename UserDefinedLinearOperatorType>
struct Selector<
iterative::GMRES, UserDefinedLinearOperatorType, void
>
{
using tag_t = iterative::GMRES;
using solver_traits = ::pressio::linearsolvers::Traits<tag_t>;
using type = EigenIterativeMatrixFree<tag_t, UserDefinedLinearOperatorType>;
};

template<typename TagType, typename MatrixType>
struct Selector<
TagType, MatrixType,
Expand Down
18 changes: 18 additions & 0 deletions include/pressio/solvers_linear/impl/solvers_linear_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include <Eigen/Sparse>
#include <Eigen/SparseQR>
#include <Eigen/OrderingMethods>
#include <unsupported/Eigen/IterativeSolvers>
#endif

namespace pressio{ namespace linearsolvers{
Expand All @@ -73,6 +74,23 @@ struct Traits {
#endif
};

template <>
struct Traits<::pressio::linearsolvers::iterative::GMRES>
{
static constexpr bool direct = false;
static constexpr bool iterative = true;

#ifdef PRESSIO_ENABLE_TPL_EIGEN
template <
typename MatrixOrOperatorT,
typename PrecT = Eigen::IdentityPreconditioner
>
using eigen_solver_type = Eigen::GMRES<MatrixOrOperatorT, PrecT>;

static constexpr bool eigen_enabled = true;
#endif
};

template <>
struct Traits<::pressio::linearsolvers::iterative::CG>
{
Expand Down
1 change: 1 addition & 0 deletions include/pressio/solvers_linear/solvers_linear_tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace iterative{
struct CG {};
struct LSCG {};
struct Bicgstab {};
struct GMRES{};
}

namespace direct{
Expand Down
17 changes: 17 additions & 0 deletions include/pressio/solvers_nonlinear/impl/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ void compute_residual(RegistryType & reg,
system.residual(state, r);
}

#ifdef PRESSIO_ENABLE_CXX20
template<class RegistryType, class SystemType>
requires NonlinearSystem<SystemType>
#else
template<
class RegistryType, class SystemType,
std::enable_if_t< NonlinearSystem<SystemType>::value, int> = 0
>
#endif
void compute_residual(RegistryType & reg,
const SystemType & system)
{
const auto & state = reg.template get<StateTag>();
auto & r = reg.template get<ResidualTag>();
system.residual(state, r);
}

template<class RegistryType, class SystemType>
void compute_residual_and_jacobian(RegistryType & reg,
const SystemType & system)
Expand Down
1 change: 1 addition & 0 deletions include/pressio/solvers_nonlinear/impl/internal_tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct QTransposeResidualTag{};


struct NewtonTag{};
struct MatrixFreeNewtonTag{};
struct GaussNewtonNormalEqTag{};
struct WeightedGaussNewtonNormalEqTag{};
struct LevenbergMarquardtNormalEqTag{};
Expand Down
38 changes: 38 additions & 0 deletions include/pressio/solvers_nonlinear/impl/registries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,44 @@ class RegistryNewton
GETMETHOD(6)
};

template<class SystemType, class LinearSolverTag>
class RegistryMatrixFreeNewtonKrylov
{
using state_t = typename SystemType::state_type;
using r_t = typename SystemType::residual_type;

using Tag1 = nonlinearsolvers::CorrectionTag;
using Tag2 = nonlinearsolvers::InitialGuessTag;
using Tag3 = nonlinearsolvers::ResidualTag;
using Tag4 = nonlinearsolvers::impl::SystemTag;

state_t d1_;
state_t d2_;
r_t d3_;
SystemType const * d4_;

public:
using linear_solver_tag = LinearSolverTag;

RegistryMatrixFreeNewtonKrylov(const SystemType & system)
: d1_(system.createState()),
d2_(system.createState()),
d3_(system.createResidual()),
d4_(&system){}

template<class TagToFind>
static constexpr bool contains(){
return (mpl::variadic::find_if_binary_pred_t<TagToFind, std::is_same,
Tag1, Tag2, Tag3, Tag4>::value) < 4;
}

GETMETHOD(1)
GETMETHOD(2)
GETMETHOD(3)
GETMETHOD(4)
};


template<class SystemType, class InnSolverType>
class RegistryGaussNewtonNormalEqs
{
Expand Down
Loading

0 comments on commit 8c95f00

Please sign in to comment.