Skip to content

Commit

Permalink
Add SparseDiffTools extension
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Apr 9, 2024
1 parent f4af2a7 commit 01e36f7
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 3 deletions.
33 changes: 31 additions & 2 deletions ext/SparseConnectivityTracerSparseDiffToolsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
module SparseConnectivityTracerSparseDiffToolsExt

using SparseConnectivityTracer
using SparseDiffTools: SparseDiffTools
using SparseConnectivityTracer: connectivity
using SparseDiffTools:
AbstractSparseADType,
AbstractSparsityDetection,
ArrayInterface,
GreedyD1Color,
JacPrototypeSparsityDetection,
SparseDiffTools

Base.@kwdef struct ConnectivityTracerSparsityDetection{
A<:ArrayInterface.ColoringAlgorithm
} <: AbstractSparsityDetection
alg::A = GreedyD1Color()
end

function (alg::ConnectivityTracerSparsityDetection)(
ad::AbstractSparseADType, f, x; fx=nothing, kwargs...
)
fx = fx === nothing ? similar(f(x)) : dx
J = connectivity(f, x)
_alg = JacPrototypeSparsityDetection(J, alg.alg)
return _alg(ad, f, x; fx, kwargs...)
end

function (alg::ConnectivityTracerSparsityDetection)(
ad::AbstractSparseADType, f!, fx, x; kwargs...
)
J = connectivity(f!, fx, x)
_alg = JacPrototypeSparsityDetection(J, alg.alg)
return _alg(ad, f!, fx, x; kwargs...)
end

end
104 changes: 103 additions & 1 deletion test/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

julia_version = "1.10.2"
manifest_format = "2.0"
project_hash = "4486f686024f7421deb8134dae1d86752576d670"
project_hash = "b9592a2a9910b3a46570ba2c71ed9e1fe833e1a1"

[[deps.ADTypes]]
git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24"
Expand Down Expand Up @@ -60,6 +60,12 @@ version = "0.8.4"
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
version = "1.1.1"

[[deps.ArnoldiMethod]]
deps = ["LinearAlgebra", "Random", "StaticArrays"]
git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6"
uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
version = "0.4.0"

[[deps.ArrayInterface]]
deps = ["Adapt", "LinearAlgebra", "SparseArrays", "SuiteSparse"]
git-tree-sha1 = "44691067188f6bd1b2289552a23e4b7572f4528d"
Expand Down Expand Up @@ -371,6 +377,22 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"]
FillArraysSparseArraysExt = "SparseArrays"
FillArraysStatisticsExt = "Statistics"

[[deps.FiniteDiff]]
deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"]
git-tree-sha1 = "bc0c5092d6caaea112d3c8e3b238d61563c58d5f"
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
version = "2.23.0"

[deps.FiniteDiff.extensions]
FiniteDiffBandedMatricesExt = "BandedMatrices"
FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices"
FiniteDiffStaticArraysExt = "StaticArrays"

[deps.FiniteDiff.weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[[deps.FixedPointNumbers]]
deps = ["Statistics"]
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc"
Expand Down Expand Up @@ -430,6 +452,12 @@ git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496"
uuid = "c27321d9-0574-5035-807b-f59d2c89b15c"
version = "1.3.1"

[[deps.Graphs]]
deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
git-tree-sha1 = "3863330da5466410782f2bffc64f3d505a6a8334"
uuid = "86223c79-3864-5bf0-83f7-82e725a168b6"
version = "1.10.0"

[[deps.HypergeometricFunctions]]
deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"]
git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685"
Expand Down Expand Up @@ -459,6 +487,11 @@ git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0"
uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534"
version = "0.10.2"

[[deps.Inflate]]
git-tree-sha1 = "ea8031dea4aff6bd41f1df8f2fdfb25b33626381"
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
version = "0.1.4"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down Expand Up @@ -792,6 +825,12 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.11.31"

[[deps.PackageExtensionCompat]]
git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518"
uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930"
version = "1.0.2"
weakdeps = ["Requires", "TOML"]

[[deps.PaddedViews]]
deps = ["OffsetArrays"]
git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f"
Expand Down Expand Up @@ -977,6 +1016,16 @@ git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac"
uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46"
version = "1.1.1"

[[deps.SharedArrays]]
deps = ["Distributed", "Mmap", "Random", "Serialization"]
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"

[[deps.SimpleTraits]]
deps = ["InteractiveUtils", "MacroTools"]
git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231"
uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
version = "0.9.4"

[[deps.Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

Expand All @@ -991,6 +1040,26 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
version = "1.10.0"

[[deps.SparseDiffTools]]
deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"]
git-tree-sha1 = "a616ac46c38da60ac05cecf52064d44732edd05e"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
version = "2.17.0"

[deps.SparseDiffTools.extensions]
SparseDiffToolsEnzymeExt = "Enzyme"
SparseDiffToolsPolyesterExt = "Polyester"
SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff"
SparseDiffToolsSymbolicsExt = "Symbolics"
SparseDiffToolsZygoteExt = "Zygote"

[deps.SparseDiffTools.weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[[deps.SpecialFunctions]]
deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"]
git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d"
Expand All @@ -1007,6 +1076,23 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c"
uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
version = "0.1.1"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.8.10"

[[deps.StaticArrayInterface]]
deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Requires", "SparseArrays", "Static", "SuiteSparse"]
git-tree-sha1 = "5d66818a39bb04bf328e92bc933ec5b4ee88e436"
uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718"
version = "1.5.0"
weakdeps = ["OffsetArrays", "StaticArrays"]

[deps.StaticArrayInterface.extensions]
StaticArrayInterfaceOffsetArraysExt = "OffsetArrays"
StaticArrayInterfaceStaticArraysExt = "StaticArrays"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2"
Expand Down Expand Up @@ -1148,6 +1234,11 @@ weakdeps = ["Random", "Test"]
[deps.TranscodingStreams.extensions]
TestExt = ["Test", "Random"]

[[deps.Tricks]]
git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f"
uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
version = "0.1.8"

[[deps.URIs]]
git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b"
uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"
Expand All @@ -1157,6 +1248,11 @@ version = "1.5.1"
deps = ["Random", "SHA"]
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[[deps.UnPack]]
git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b"
uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
version = "1.0.2"

[[deps.Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Expand All @@ -1177,6 +1273,12 @@ git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e"
uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249"
version = "0.1.3"

[[deps.VertexSafeGraphs]]
deps = ["Graphs"]
git-tree-sha1 = "8351f8d73d7e880bfc042a8b6922684ebeafb35c"
uuid = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
version = "0.2.0"

[[deps.XTermColors]]
deps = ["Crayons", "ImageBase", "OffsetArrays"]
git-tree-sha1 = "bc27b7622a51f570c57b80bd839d1c0d43605b38"
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ DocMeta.setdocmeta!(
@test C == C_ref
end
end
@testset "SparseDiffTools integration" begin
include("sparsedifftools.jl")
end
@testset "Doctests" begin
Documenter.doctest(SparseConnectivityTracer)
end
Expand Down
23 changes: 23 additions & 0 deletions test/sparsedifftools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Base: get_extension
using ForwardDiff: ForwardDiff
using SparseArrays
using SparseConnectivityTracer
using SparseDiffTools
using Test

ext = Base.get_extension(
SparseConnectivityTracer, :SparseConnectivityTracerSparseDiffToolsExt
)
@test !isnothing(ext)

sd = ext.ConnectivityTracerSparsityDetection()
adtype = SparseDiffTools.AutoSparseForwardDiff()

x = rand(10)
y = zeros(9)
J1 = sparse_jacobian(adtype, sd, diff, x)
J2 = sparse_jacobian(adtype, sd, (y, x) -> y .= diff(x), y, x)
@test J1 == J2
@test J1 isa SparseMatrixCSC
@test J2 isa SparseMatrixCSC
@test nnz(J1) == nnz(J2) == 18

0 comments on commit 01e36f7

Please sign in to comment.