Skip to content

Commit

Permalink
add newton solver with gmres
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Rizzi committed Nov 13, 2024
1 parent 3b56abb commit 43c54b3
Show file tree
Hide file tree
Showing 17 changed files with 376 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,34 +55,32 @@

namespace pressio { namespace linearsolvers{

template<typename UserDefinedOperatorType>
template<typename UserDefinedLinearOperatorType>
class OperatorWrapper;

}}

namespace Eigen {
namespace internal {

template<typename UserDefinedOperatorType>
struct traits< pressio::linearsolvers::OperatorWrapper<UserDefinedOperatorType> >
template<typename UserDefinedLinearOperatorType>
struct traits< pressio::linearsolvers::OperatorWrapper<UserDefinedLinearOperatorType> >
: public Eigen::internal::traits<
Eigen::Matrix<typename UserDefinedOperatorType::scalar_type,-1,-1>
Eigen::Matrix<typename UserDefinedLinearOperatorType::scalar_type,-1,-1>
>
{};

}
}

namespace pressio { namespace linearsolvers{

template<typename UserDefinedOperatorType>
template<typename UserDefinedLinearOperatorType>
class OperatorWrapper :
public Eigen::EigenBase<
OperatorWrapper<UserDefinedOperatorType>
OperatorWrapper<UserDefinedLinearOperatorType>
>
{
public:
using Scalar = typename UserDefinedOperatorType::scalar_type;
using Scalar = typename UserDefinedLinearOperatorType::scalar_type;
using RealScalar = Scalar;
using StorageIndex = int;
enum {
Expand All @@ -92,35 +90,34 @@ class OperatorWrapper :

OperatorWrapper() = default;

OperatorWrapper(UserDefinedOperatorType const & valueIn)
OperatorWrapper(UserDefinedLinearOperatorType const & valueIn)
: m_userOperator(&valueIn) {}


int rows() const { return m_userOperator->rows(); }
int cols() const { return m_userOperator->cols(); }

template<typename Rhs>
Eigen::Product<OperatorWrapper<UserDefinedOperatorType>, Rhs, Eigen::AliasFreeProduct>
Eigen::Product<OperatorWrapper<UserDefinedLinearOperatorType>, Rhs, Eigen::AliasFreeProduct>
operator*(const Eigen::MatrixBase<Rhs>& x) const{
using r_t = Eigen::Product<OperatorWrapper<UserDefinedOperatorType>, Rhs, Eigen::AliasFreeProduct>;
using r_t = Eigen::Product<
OperatorWrapper<UserDefinedLinearOperatorType>, Rhs, Eigen::AliasFreeProduct
>;
return r_t(*this, x.derived());
}

void replace(const UserDefinedOperatorType & opIn) {
void replace(const UserDefinedLinearOperatorType & opIn) {
m_userOperator = &opIn;
}

template<class OperandT, class ResultT>
void applyAndAddTo(OperandT const & operand, ResultT & out) const {
//
// compute: out += operator * operand

// out += *m_userOperator * operand;
m_userOperator->applyAndAddTo(operand, out);
}

private:
UserDefinedOperatorType const *m_userOperator = nullptr;
UserDefinedLinearOperatorType const *m_userOperator = nullptr;
};

}} // end namespace pressio::linearsolvers
Expand All @@ -134,8 +131,8 @@ namespace Eigen {
Rhs, DenseShape, DenseShape, GemvProduct
>
: generic_product_impl_base<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs,
generic_product_impl<pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs>
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs,
generic_product_impl<pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs>
>
{
using Scalar = typename Product<
Expand Down Expand Up @@ -164,17 +161,18 @@ namespace Eigen {

namespace pressio { namespace linearsolvers{ namespace impl{

template<typename TagType, typename UserDefinedOperatorType>
template<typename TagType, typename UserDefinedLinearOperatorType>
class EigenIterativeMatrixFree
: public IterativeBase< EigenIterativeMatrixFree<TagType, UserDefinedOperatorType>>
: public IterativeBase<
EigenIterativeMatrixFree<TagType, UserDefinedLinearOperatorType>
>
{

public:
using this_type = EigenIterative<TagType, UserDefinedOperatorType>;
// using matrix_type = UserDefinedOperatorType;
using scalar_type = typename UserDefinedOperatorType::scalar_type;
using this_type = EigenIterative<TagType, UserDefinedLinearOperatorType>;
using scalar_type = typename UserDefinedLinearOperatorType::scalar_type;
using solver_traits = ::pressio::linearsolvers::Traits<TagType>;
using op_wrapper_t = OperatorWrapper<UserDefinedOperatorType>;
using op_wrapper_t = OperatorWrapper<UserDefinedLinearOperatorType>;
using native_solver_type = typename solver_traits::template eigen_solver_type<op_wrapper_t>;
using base_iterative_type = IterativeBase<this_type>;
using iteration_type = typename base_iterative_type::iteration_type;
Expand All @@ -195,7 +193,7 @@ class EigenIterativeMatrixFree
return mysolver_.error();
}

void resetLinearSystem(const UserDefinedOperatorType& A)
void resetLinearSystem(const UserDefinedLinearOperatorType& A)
{
mysolver_.setMaxIterations(this->maxIters_);
m_wrapper.replace(A);
Expand All @@ -209,7 +207,7 @@ class EigenIterativeMatrixFree
}

template <typename T>
void solve(const UserDefinedOperatorType & A, const T& b, T & y){
void solve(const UserDefinedLinearOperatorType & A, const T& b, T & y){
this->resetLinearSystem(A);
this->solve(b, y);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ struct Selector{
};

#ifdef PRESSIO_ENABLE_TPL_EIGEN
template<typename UserDefinedOperatorType>
template<typename UserDefinedLinearOperatorType>
struct Selector<
iterative::GMRES, UserDefinedOperatorType, void
iterative::GMRES, UserDefinedLinearOperatorType, void
>
{
using tag_t = iterative::GMRES;
using solver_traits = ::pressio::linearsolvers::Traits<tag_t>;
using type = EigenIterativeMatrixFree<tag_t, UserDefinedOperatorType>;
using type = EigenIterativeMatrixFree<tag_t, UserDefinedLinearOperatorType>;
};

template<typename TagType, typename MatrixType>
Expand Down
17 changes: 17 additions & 0 deletions include/pressio/solvers_nonlinear/impl/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ void compute_residual(RegistryType & reg,
system.residual(state, r);
}

#ifdef PRESSIO_ENABLE_CXX20
template<class RegistryType, class SystemType>
requires NonlinearSystem<SystemType>
#else
template<
class RegistryType, class SystemType,
std::enable_if_t< NonlinearSystem<SystemType>::value, int> = 0
>
#endif
void compute_residual(RegistryType & reg,
const SystemType & system)
{
const auto & state = reg.template get<StateTag>();
auto & r = reg.template get<ResidualTag>();
system.residual(state, r);
}

template<class RegistryType, class SystemType>
void compute_residual_and_jacobian(RegistryType & reg,
const SystemType & system)
Expand Down
1 change: 1 addition & 0 deletions include/pressio/solvers_nonlinear/impl/internal_tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct QTransposeResidualTag{};


struct NewtonTag{};
struct MatrixFreeNewtonTag{};
struct GaussNewtonNormalEqTag{};
struct WeightedGaussNewtonNormalEqTag{};
struct LevenbergMarquardtNormalEqTag{};
Expand Down
38 changes: 38 additions & 0 deletions include/pressio/solvers_nonlinear/impl/registries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,44 @@ class RegistryNewton
GETMETHOD(6)
};

template<class SystemType, class LinearSolverTag>
class RegistryMatrixFreeNewtonKrylov
{
using state_t = typename SystemType::state_type;
using r_t = typename SystemType::residual_type;

using Tag1 = nonlinearsolvers::CorrectionTag;
using Tag2 = nonlinearsolvers::InitialGuessTag;
using Tag3 = nonlinearsolvers::ResidualTag;
using Tag4 = nonlinearsolvers::impl::SystemTag;

state_t d1_;
state_t d2_;
r_t d3_;
SystemType const * d4_;

public:
using linear_solver_tag = LinearSolverTag;

RegistryMatrixFreeNewtonKrylov(const SystemType & system)
: d1_(system.createState()),
d2_(system.createState()),
d3_(system.createResidual()),
d4_(&system){}

template<class TagToFind>
static constexpr bool contains(){
return (mpl::variadic::find_if_binary_pred_t<TagToFind, std::is_same,
Tag1, Tag2, Tag3, Tag4>::value) < 4;
}

GETMETHOD(1)
GETMETHOD(2)
GETMETHOD(3)
GETMETHOD(4)
};


template<class SystemType, class InnSolverType>
class RegistryGaussNewtonNormalEqs
{
Expand Down
Loading

0 comments on commit 43c54b3

Please sign in to comment.