From 01e36f7783d990ed94b40362ba0ace93c7c38686 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 9 Apr 2024 18:23:00 +0200 Subject: [PATCH] Add SparseDiffTools extension --- ...rseConnectivityTracerSparseDiffToolsExt.jl | 33 +++++- test/Manifest.toml | 104 +++++++++++++++++- test/Project.toml | 2 + test/runtests.jl | 3 + test/sparsedifftools.jl | 23 ++++ 5 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 test/sparsedifftools.jl diff --git a/ext/SparseConnectivityTracerSparseDiffToolsExt.jl b/ext/SparseConnectivityTracerSparseDiffToolsExt.jl index dcce25b0..e646414c 100644 --- a/ext/SparseConnectivityTracerSparseDiffToolsExt.jl +++ b/ext/SparseConnectivityTracerSparseDiffToolsExt.jl @@ -1,6 +1,35 @@ module SparseConnectivityTracerSparseDiffToolsExt -using SparseConnectivityTracer -using SparseDiffTools: SparseDiffTools +using SparseConnectivityTracer: connectivity +using SparseDiffTools: + AbstractSparseADType, + AbstractSparsityDetection, + ArrayInterface, + GreedyD1Color, + JacPrototypeSparsityDetection, + SparseDiffTools + +Base.@kwdef struct ConnectivityTracerSparsityDetection{ + A<:ArrayInterface.ColoringAlgorithm +} <: AbstractSparsityDetection + alg::A = GreedyD1Color() +end + +function (alg::ConnectivityTracerSparsityDetection)( + ad::AbstractSparseADType, f, x; fx=nothing, kwargs... +) + fx = fx === nothing ? similar(f(x)) : dx + J = connectivity(f, x) + _alg = JacPrototypeSparsityDetection(J, alg.alg) + return _alg(ad, f, x; fx, kwargs...) +end + +function (alg::ConnectivityTracerSparsityDetection)( + ad::AbstractSparseADType, f!, fx, x; kwargs... +) + J = connectivity(f!, fx, x) + _alg = JacPrototypeSparsityDetection(J, alg.alg) + return _alg(ad, f!, fx, x; kwargs...) +end end diff --git a/test/Manifest.toml b/test/Manifest.toml index 43d684fa..e22b5017 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.2" manifest_format = "2.0" -project_hash = "4486f686024f7421deb8134dae1d86752576d670" +project_hash = "b9592a2a9910b3a46570ba2c71ed9e1fe833e1a1" [[deps.ADTypes]] git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24" @@ -60,6 +60,12 @@ version = "0.8.4" uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" +[[deps.ArnoldiMethod]] +deps = ["LinearAlgebra", "Random", "StaticArrays"] +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" +uuid = "ec485272-7323-5ecc-a04f-4719b315124d" +version = "0.4.0" + [[deps.ArrayInterface]] deps = ["Adapt", "LinearAlgebra", "SparseArrays", "SuiteSparse"] git-tree-sha1 = "44691067188f6bd1b2289552a23e4b7572f4528d" @@ -371,6 +377,22 @@ weakdeps = ["PDMats", "SparseArrays", "Statistics"] FillArraysSparseArraysExt = "SparseArrays" FillArraysStatisticsExt = "Statistics" +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "bc0c5092d6caaea112d3c8e3b238d61563c58d5f" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.23.0" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + [[deps.FixedPointNumbers]] deps = ["Statistics"] git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" @@ -430,6 +452,12 @@ git-tree-sha1 = "97285bbd5230dd766e9ef6749b80fc617126d496" uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" version = "1.3.1" +[[deps.Graphs]] +deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] +git-tree-sha1 = "3863330da5466410782f2bffc64f3d505a6a8334" +uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" +version = "1.10.0" + [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" @@ -459,6 +487,11 @@ git-tree-sha1 = "b2a7eaa169c13f5bcae8131a83bc30eff8f71be0" uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" version = "0.10.2" +[[deps.Inflate]] +git-tree-sha1 = "ea8031dea4aff6bd41f1df8f2fdfb25b33626381" +uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" +version = "0.1.4" + [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -792,6 +825,12 @@ git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" version = "0.11.31" +[[deps.PackageExtensionCompat]] +git-tree-sha1 = "fb28e33b8a95c4cee25ce296c817d89cc2e53518" +uuid = "65ce6f38-6b18-4e1d-a461-8949797d7930" +version = "1.0.2" +weakdeps = ["Requires", "TOML"] + [[deps.PaddedViews]] deps = ["OffsetArrays"] git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" @@ -977,6 +1016,16 @@ git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" version = "1.1.1" +[[deps.SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[deps.SimpleTraits]] +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" +uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" +version = "0.9.4" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -991,6 +1040,26 @@ deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" version = "1.10.0" +[[deps.SparseDiffTools]] +deps = ["ADTypes", "Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "Graphs", "LinearAlgebra", "PackageExtensionCompat", "Random", "Reexport", "SciMLOperators", "Setfield", "SparseArrays", "StaticArrayInterface", "StaticArrays", "Tricks", "UnPack", "VertexSafeGraphs"] +git-tree-sha1 = "a616ac46c38da60ac05cecf52064d44732edd05e" +uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +version = "2.17.0" + + [deps.SparseDiffTools.extensions] + SparseDiffToolsEnzymeExt = "Enzyme" + SparseDiffToolsPolyesterExt = "Polyester" + SparseDiffToolsPolyesterForwardDiffExt = "PolyesterForwardDiff" + SparseDiffToolsSymbolicsExt = "Symbolics" + SparseDiffToolsZygoteExt = "Zygote" + + [deps.SparseDiffTools.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588" + PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b" + Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" @@ -1007,6 +1076,23 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" version = "0.1.1" +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.8.10" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "PrecompileTools", "Requires", "SparseArrays", "Static", "SuiteSparse"] +git-tree-sha1 = "5d66818a39bb04bf328e92bc933ec5b4ee88e436" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.5.0" +weakdeps = ["OffsetArrays", "StaticArrays"] + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" @@ -1148,6 +1234,11 @@ weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] TestExt = ["Test", "Random"] +[[deps.Tricks]] +git-tree-sha1 = "eae1bb484cd63b36999ee58be2de6c178105112f" +uuid = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" +version = "0.1.8" + [[deps.URIs]] git-tree-sha1 = "67db6cc7b3821e19ebe75791a9dd19c9b1188f2b" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" @@ -1157,6 +1248,11 @@ version = "1.5.1" deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" @@ -1177,6 +1273,12 @@ git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" version = "0.1.3" +[[deps.VertexSafeGraphs]] +deps = ["Graphs"] +git-tree-sha1 = "8351f8d73d7e880bfc042a8b6922684ebeafb35c" +uuid = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f" +version = "0.2.0" + [[deps.XTermColors]] deps = ["Crayons", "ImageBase", "OffsetArrays"] git-tree-sha1 = "bc27b7622a51f570c57b80bd839d1c0d43605b38" diff --git a/test/Project.toml b/test/Project.toml index 132772e2..3c77b741 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,11 +1,13 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" +SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index d06e398c..2084e43d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -89,6 +89,9 @@ DocMeta.setdocmeta!( @test C == C_ref end end + @testset "SparseDiffTools integration" begin + include("sparsedifftools.jl") + end @testset "Doctests" begin Documenter.doctest(SparseConnectivityTracer) end diff --git a/test/sparsedifftools.jl b/test/sparsedifftools.jl new file mode 100644 index 00000000..fa4caf32 --- /dev/null +++ b/test/sparsedifftools.jl @@ -0,0 +1,23 @@ +using Base: get_extension +using ForwardDiff: ForwardDiff +using SparseArrays +using SparseConnectivityTracer +using SparseDiffTools +using Test + +ext = Base.get_extension( + SparseConnectivityTracer, :SparseConnectivityTracerSparseDiffToolsExt +) +@test !isnothing(ext) + +sd = ext.ConnectivityTracerSparsityDetection() +adtype = SparseDiffTools.AutoSparseForwardDiff() + +x = rand(10) +y = zeros(9) +J1 = sparse_jacobian(adtype, sd, diff, x) +J2 = sparse_jacobian(adtype, sd, (y, x) -> y .= diff(x), y, x) +@test J1 == J2 +@test J1 isa SparseMatrixCSC +@test J2 isa SparseMatrixCSC +@test nnz(J1) == nnz(J2) == 18