Skip to content

Commit

Permalink
Improve interface (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
junyuan-chen authored Oct 1, 2024
1 parent e68c841 commit 74e442f
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ FastLapackInterface = "29a986be-02c6-4525-aec4-84b980013641"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[compat]
CommonSolve = "0.2"
FastLapackInterface = "1.2"
FastLapackInterface = "1.2, 2"
NLSolversBase = "7.6"
PositiveFactorizations = "0.2.4"
PrecompileTools = "1"
Expand Down
2 changes: 1 addition & 1 deletion src/NonlinearSystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using CommonSolve: solve
using FastLapackInterface: LUWs
using LinearAlgebra: BLAS, LAPACK, LU, Cholesky, cholesky!, Hermitian, ldiv!, mul!,
lowrankupdate!
using NLSolversBase: OnceDifferentiable, value_jacobian!!, jacobian!!
using NLSolversBase: OnceDifferentiable, value_jacobian!!, value!!, jacobian!!
using PositiveFactorizations
using Printf

Expand Down
29 changes: 18 additions & 11 deletions src/hybrid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct HybridSolver{T, L, V} <: AbstractSolver{T}
thres_jac::Int
thres_nslow1::Int
thres_nslow2::Int
warn::Bool
end

"""
Expand All @@ -79,13 +80,15 @@ See also [`Hybrid`](@ref).
- `thres_nslow2::Integer=5`: signal slow solver progress
if there is no expansion of trust region after recomputing the Jacobian matrix
in the specified number of consecutive steps.
- `warn::Bool=true`: print a warning message for slow solver progress
"""
function HybridSolver(fdf::OnceDifferentiable, x::AbstractVector, fx::AbstractVector,
J::AbstractMatrix, P::Type{<:ProblemType};
linsolver=default_linsolver(fdf, x0, P),
factor_init::Real=1.0, factor_up::Real=2.0, factor_down::Real=0.5,
scaling::Bool=true, rank1update::Bool=true,
thres_jac::Integer=2, thres_nslow1::Integer=10, thres_nslow2::Integer=5)
thres_jac::Integer=2, thres_nslow1::Integer=10, thres_nslow2::Integer=5,
warn::Bool=true)
M, N = size(J) # Assume the sizes of x, fx and J are all compatible
P === RootFinding && M != N && throw(DimensionMismatch(
"the number of variables must match the number of equations in a root-finding problem"))
Expand Down Expand Up @@ -114,7 +117,7 @@ function HybridSolver(fdf::OnceDifferentiable, x::AbstractVector, fx::AbstractVe
return HybridSolver(Ref(state), linsolver, diagn,
newton, grad, df, Jdx, w, v, factor_init, factor_up, factor_down,
scaling, rank1update, convert(Int, thres_jac),
convert(Int, thres_nslow1), convert(Int, thres_nslow2))
convert(Int, thres_nslow1), convert(Int, thres_nslow2), warn)
end

# ! This method reuses arrays from s
Expand All @@ -123,13 +126,15 @@ function init(s::NonlinearSystem{P,V,M,<:HybridSolver{T}}, x0::V;
factor_down::Real=s.solver.factor_down, scaling::Bool=s.solver.scaling,
rank1update::Bool=s.solver.rank1update, thres_jac::Integer=s.solver.thres_jac,
thres_nslow1::Integer=s.solver.thres_nslow1,
thres_nslow2::Integer=s.solver.thres_nslow2, kwargs...) where {P,V,M,T}
copyto!(s.x, x0)
thres_nslow2::Integer=s.solver.thres_nslow2,
warn::Bool=s.solver.warn,
initf::Bool=true, initdf::Bool=true, kwargs...) where {P,V,M,T}
s.x === x0 || copyto!(s.x, x0)
nan = convert(eltype(s.x), NaN)
fill!(s.dx, nan)
ss, fdf = s.solver, s.fdf
linsolver = getlinsolver(s)
init(linsolver, fdf, s.x)
init(linsolver, fdf, s.x; initf=initf, initdf=initdf)
fill!(ss.grad, nan)
copyto!(s.fx, fdf.F)
diagn = ss.diagn
Expand All @@ -146,8 +151,10 @@ function init(s::NonlinearSystem{P,V,M,<:HybridSolver{T}}, x0::V;
solver = HybridSolver(Ref(state), linsolver, diagn,
ss.newton, ss.grad, ss.df, ss.Jdx, ss.w, ss.v, factor_init, factor_up, factor_down,
scaling, rank1update, convert(Int, thres_jac),
convert(Int, thres_nslow1), convert(Int, thres_nslow2))
return NonlinearSystem(P, s.fdf, s.x, s.fx, s.dx, solver; kwargs...)
convert(Int, thres_nslow1), convert(Int, thres_nslow2), warn)
return NonlinearSystem(P, s.fdf, s.x, s.fx, s.dx, solver;
lower=s.lb, upper=s.ub, maxiter=s.maxiter, ftol=s.ftol, gtol=s.gtol,
xtol=s.xtol, xtolr=s.xtolr, showtrace=s.showtrace, kwargs...)
end

function dogleg!(dx, linsolver, J, fx, diagn, δ, newton, grad, w)
Expand Down Expand Up @@ -291,10 +298,10 @@ function (s::HybridSolver{T})(fdf::OnceDifferentiable, x::AbstractVector,

if nslow1 === s.thres_nslow1
if nslow2 >= s.thres_nslow2
@warn "iteration $(iter) is not making progress even with reevaluations of Jacobians; try a smaller value with option thres_jac"
s.warn && @warn "iteration $(iter) is not making progress even with reevaluations of Jacobians; try a smaller value with option thres_jac"
return iter, jac_noprogress
else
@warn "iteration $(iter) is not making progress"
s.warn && @warn "iteration $(iter) is not making progress"
return iter, eval_noprogress
end
else
Expand All @@ -306,15 +313,15 @@ function init(::Type{Hybrid{P}}, fdf::OnceDifferentiable, x0::AbstractVector;
linsolver=default_linsolver(fdf, x0, P),
factor_init::Real=1.0, factor_up::Real=2.0, factor_down::Real=0.5,
scaling=true, rank1update=true,
thres_jac=2, thres_nslow1=10, thres_nslow2=5, kwargs...) where P
thres_jac=2, thres_nslow1=10, thres_nslow2=5, warn=true, kwargs...) where P
x = copy(x0)
fx = copy(fdf.F)
dx = similar(x) # Preserve the array type
fill!(dx, convert(eltype(x), NaN))
solver = HybridSolver(fdf, x, fx, fdf.DF, P; linsolver=linsolver,
factor_init=factor_init, factor_up=factor_up, factor_down=factor_down,
scaling=scaling, rank1update=rank1update,
thres_jac=thres_jac, thres_nslow1=thres_nslow1, thres_nslow2=thres_nslow2)
thres_jac=thres_jac, thres_nslow1=thres_nslow1, thres_nslow2=thres_nslow2, warn=warn)
# Remaining kwargs are handled by NonlinearSystem constructor
return NonlinearSystem(P, fdf, x, fx, dx, solver; kwargs...)
end
Expand Down
10 changes: 6 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,13 @@ solve!(s::NonlinearSystem, x0; kwargs...) = solve!(init(s, x0; kwargs...))
init(algo::AbstractAlgorithm, args...; kwargs...) =
init(typeof(algo), args...; kwargs...)

init(Algo::Type{<:AbstractAlgorithm}, f, x0::AbstractVector; kwargs...) =
init(Algo, OnceDifferentiable(f, similar(x0), similar(x0)), x0; kwargs...)
init(Algo::Type{<:AbstractAlgorithm}, f, x0::AbstractVector, nf::Int=length(x0);
kwargs...) =
init(Algo, OnceDifferentiable(f, similar(x0), similar(x0, nf)), x0; kwargs...)

init(Algo::Type{<:AbstractAlgorithm}, f, j, x0::AbstractVector; kwargs...) =
init(Algo, OnceDifferentiable(f, j, similar(x0), similar(x0)), x0; kwargs...)
init(Algo::Type{<:AbstractAlgorithm}, f, j, x0::AbstractVector, nf::Int=length(x0);
kwargs...) =
init(Algo, OnceDifferentiable(f, j, similar(x0), similar(x0, nf)), x0; kwargs...)

function show(io::IO, s::NonlinearSystem{P}) where P
print(io, nequ(s), '×', nvar(s), ' ', typeof(s).name.name, '{', P, "}(")
Expand Down
20 changes: 12 additions & 8 deletions src/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ default_linsolver(fdf::OnceDifferentiable{V, M, V}, x0::V,
::Type{RootFinding}) where {V<:AbstractVector, M<:AbstractMatrix} =
init(DenseLUSolver, fdf, x0)

init(::Type{DenseLUSolver}, fdf::OnceDifferentiable, x0::AbstractVector) =
(_init!(fdf, x0); init(DenseLUSolver, fdf.DF, fdf.F))
init(::Type{DenseLUSolver}, fdf::OnceDifferentiable, x0::AbstractVector;
initf::Bool=true, initdf::Bool=true) =
(_init!(fdf, x0, initf, initdf); init(DenseLUSolver, fdf.DF, fdf.F))

function init(::Type{DenseLUSolver}, J::AbstractMatrix, f::AbstractVector)
ws = LUWs(J)
Expand All @@ -53,8 +54,9 @@ function init(::Type{DenseLUSolver}, J::AbstractMatrix, f::AbstractVector)
return DenseLUSolver(ws, fac)
end

init(s::DenseLUSolver, fdf::OnceDifferentiable, x0::AbstractVector) =
(_init!(fdf, x0); update!(s, fdf.DF); s)
init(s::DenseLUSolver, fdf::OnceDifferentiable, x0::AbstractVector;
initf::Bool=true, initdf::Bool=true) =
(_init!(fdf, x0, initf, initdf); update!(s, fdf.DF); s)

update!(s::DenseLUSolver{<:LU}, J::AbstractMatrix) =
(s.fac = LU(LAPACK.getrf!(s.ws, copyto!(getfield(s.fac, :factors), J))...); nothing)
Expand Down Expand Up @@ -144,8 +146,9 @@ default_linsolver(fdf::OnceDifferentiable{V, M, V}, x0::V,
::Type{LeastSquares}) where {V<:AbstractVector, M<:AbstractMatrix} =
init(DenseCholeskySolver, fdf, x0)

init(::Type{DenseCholeskySolver}, fdf::OnceDifferentiable, x0::AbstractVector; kwargs...) =
(_init!(fdf, x0); init(DenseCholeskySolver, fdf.DF, fdf.F; kwargs...))
init(::Type{DenseCholeskySolver}, fdf::OnceDifferentiable, x0::AbstractVector;
initf::Bool=true, initdf::Bool=true, kwargs...) =
(_init!(fdf, x0, initf, initdf); init(DenseCholeskySolver, fdf.DF, fdf.F; kwargs...))

function init(::Type{DenseCholeskySolver}, J::AbstractMatrix, f::AbstractVector;
rank1chol::Bool=length(f)>3)
Expand All @@ -160,8 +163,9 @@ function init(::Type{DenseCholeskySolver}, J::AbstractMatrix, f::AbstractVector;
end
end

init(s::DenseCholeskySolver, fdf::OnceDifferentiable{V, M, V}, x0::V) where {V, M} =
(_init!(fdf, x0); update!(s, fdf.DF); s)
init(s::DenseCholeskySolver, fdf::OnceDifferentiable{V, M, V}, x0::V;
initf::Bool=true, initdf::Bool=true) where {V, M} =
(_init!(fdf, x0, initf, initdf); update!(s, fdf.DF); s)

update!(s::DenseCholeskySolver, J::AbstractMatrix) =
(s.fac = _cholesky!(mul!(s.JtJ, J', J), s.d); return nothing)
Expand Down
10 changes: 8 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@ end
return sqrt(e2)
end

function _init!(fdf::OnceDifferentiable, x0)
function _init!(fdf::OnceDifferentiable, x0, initf::Bool, initdf::Bool)
fdf.f_calls[1] = 0
fdf.df_calls[1] = 0
value_jacobian!!(fdf, x0)
if initf && initdf
value_jacobian!!(fdf, x0)
elseif initf
value!!(fdf, x0)
elseif initdf
jacobian!!(fdf, x0)
end
return fdf
end

Expand Down
30 changes: 30 additions & 0 deletions test/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ end
s = default_linsolver(fdf, x0, RootFinding)
Y = zeros(2)
@test solve!(s, Y, copy(fdf.DF), copy(fdf.F)) fdf.DF \ fdf.F

x = rand(2)
init(s, fdf, x)
F = zeros(2)
J = zeros(2, 2)
@test fdf.F == f!(F, x)
@test fdf.DF == j!(J, x)
x1 = rand(2)
init(s, fdf, x1; initf=false)
@test fdf.F == f!(F, x)
@test fdf.DF == j!(J, x1)
x2 = rand(2)
init(s, fdf, x2; initdf=false)
@test fdf.F == f!(F, x2)
@test fdf.DF == j!(J, x1)
end

@testset "DenseCholeskySolver" begin
Expand Down Expand Up @@ -60,4 +75,19 @@ end
update!(s1, copy(J), copy(w), copy(v))
J1 = J .+ w .* v'
@test s1.fac.U's1.fac.U J1'J1

x = rand(2)
init(s, fdf, x)
F = zeros(2)
J = zeros(2, 2)
@test fdf.F == f!(F, x)
@test fdf.DF == j!(J, x)
x1 = rand(2)
init(s, fdf, x1; initf=false)
@test fdf.F == f!(F, x)
@test fdf.DF == j!(J, x1)
x2 = rand(2)
init(s, fdf, x2; initdf=false)
@test fdf.F == f!(F, x2)
@test fdf.DF == j!(J, x1)
end

0 comments on commit 74e442f

Please sign in to comment.