Skip to content

Commit

Permalink
added callback
Browse files Browse the repository at this point in the history
  • Loading branch information
teseoch committed Feb 3, 2025
1 parent 1f92bba commit 11cbff3
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 8 deletions.
12 changes: 10 additions & 2 deletions src/polysolve/nonlinear/Criteria.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,27 @@ namespace polysolve::nonlinear

void Criteria::reset()
{
constexpr double NaN = std::numeric_limits<double>::quiet_NaN();

iterations = 0;
xDelta = 0;
fDelta = 0;
gradNorm = 0;
firstGradNorm = 0;
fDeltaCount = 0;
xDeltaDotGrad = 0;

energy = NaN;
alpha = NaN;
step = NaN;
}

void Criteria::print(std::ostream &os) const
{
os << print_message();
}
std::string Criteria::print_message() const {
std::string Criteria::print_message() const
{
return fmt::format(
"iters={:d} Δf={:g} ‖∇f‖={:g} ‖Δx‖={:g} Δx⋅∇f(x)={:g}",
iterations, fDelta, gradNorm, xDelta, xDeltaDotGrad);
Expand Down Expand Up @@ -64,7 +71,8 @@ namespace polysolve::nonlinear
return Status::Continue;
}

std::string_view status_message(Status s) {
std::string_view status_message(Status s)
{
switch (s)
{
case Status::NotStarted:
Expand Down
5 changes: 4 additions & 1 deletion src/polysolve/nonlinear/Criteria.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ namespace polysolve::nonlinear
double xDeltaDotGrad; ///< Dot product of parameter vector and gradient vector
unsigned fDeltaCount; ///< Number of steps where fDelta is satisfied

double alpha; ///< LS alpha
double energy; ///< Energy at the current step
double step; ///< alpha * grad.norm()
Criteria();

void reset();
Expand All @@ -50,7 +53,7 @@ namespace polysolve::nonlinear
Status checkConvergence(const Criteria &stop, const Criteria &current);

std::string_view status_message(Status s);
std::string criteria_message(const Criteria& s);
std::string criteria_message(const Criteria &s);

std::ostream &operator<<(std::ostream &os, const Status &s);

Expand Down
9 changes: 9 additions & 0 deletions src/polysolve/nonlinear/Solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ namespace polysolve::nonlinear
POLYSOLVE_SCOPED_STOPWATCH("compute objective function", obj_fun_time, m_logger);
energy = objFunc(x);
}
m_current.energy = energy;

if (!std::isfinite(energy))
{
Expand Down Expand Up @@ -413,6 +414,7 @@ namespace polysolve::nonlinear
POLYSOLVE_SCOPED_STOPWATCH("line search", line_search_time, m_logger);
rate = m_line_search->line_search(x, delta_x, objFunc);
}
m_current.alpha = rate;

if (std::isnan(rate))
{
Expand Down Expand Up @@ -466,6 +468,7 @@ namespace polysolve::nonlinear
// Post update
// -----------
const double step = (rate * delta_x).norm();
m_current.step = step;

// m_logger.debug("[{}][{}] rate={:g} ‖step‖={:g}",
// descent_strategy_name(), m_line_search->name(), rate, step);
Expand All @@ -481,6 +484,12 @@ namespace polysolve::nonlinear

m_current.fDeltaCount = (m_current.fDelta < m_stop.fDelta) ? (m_current.fDeltaCount + 1) : 0;

if (m_iteration_callback && m_iteration_callback(m_current))
{
m_status = Status::ObjectiveCustomStop;
m_logger.debug("[{}][{}] Iteration callback decided to stop", descent_strategy_name(), m_line_search->name());

Check warning on line 490 in src/polysolve/nonlinear/Solver.cpp

View check run for this annotation

Codecov / codecov/patch

src/polysolve/nonlinear/Solver.cpp#L489-L490

Added lines #L489 - L490 were not covered by tests
}

m_logger.debug(
"[{}][{}] {} (stopping criteria: {})",
descent_strategy_name(), m_line_search->name(), m_current.print_message(), m_stop.print_message());
Expand Down
13 changes: 8 additions & 5 deletions src/polysolve/nonlinear/Solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ namespace polysolve::nonlinear
void set_line_search(const json &params);
const json &info() const { return solver_info; }

void set_iteration_callback(std::function<bool(const Criteria &)> callback) { m_iteration_callback = callback; }

/// @brief If true the solver will not throw an error if the maximum number of iterations is reached
bool allow_out_of_iterations = false;


/// @brief Get the line search object
const std::shared_ptr<line_search::LineSearch> &line_search() const { return m_line_search; };

protected:
/// @brief Compute direction in which the argument should be updated
/// @brief Compute direction in which the argument should be updated
/// @param objFunc Problem to be minimized
/// @param x Current input (n x 1)
/// @param grad Gradient at current step (n x 1)
Expand All @@ -106,8 +107,8 @@ namespace polysolve::nonlinear
Criteria m_stop;

/// @brief Current criteria
Criteria m_current;

Criteria m_current;
/// @brief Current status
Status m_status = Status::NotStarted;

Expand Down Expand Up @@ -139,7 +140,7 @@ namespace polysolve::nonlinear
// ====================================================================
// Solver state
// ====================================================================

/// @brief Reset the solver at the start of a minimization
/// @param ndof number of degrees of freedom
void reset(const int ndof);
Expand All @@ -151,12 +152,14 @@ namespace polysolve::nonlinear

std::vector<int> m_iter_per_strategy;

std::function<bool(const Criteria &)> m_iteration_callback = nullptr;

// ====================================================================
// Solver info
// ====================================================================

/// @brief Update solver info JSON object
/// @param energy
/// @param energy
void update_solver_info(const double energy);

/// @brief Reset timing members to 0
Expand Down

0 comments on commit 11cbff3

Please sign in to comment.