From 17b2352607511de5392d3827b24e0d3d600665b4 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 21 Aug 2024 14:47:53 +0200 Subject: [PATCH 1/5] Reorganize code --- src/SparseConnectivityTracer.jl | 4 +- src/overloads/ambiguities.jl | 35 +++++++++++++++ src/overloads/gradient_tracer.jl | 46 -------------------- src/overloads/hessian_tracer.jl | 47 --------------------- src/overloads/special_cases.jl | 25 +++++++++++ src/overloads/{overload_all.jl => utils.jl} | 0 6 files changed, 63 insertions(+), 94 deletions(-) create mode 100644 src/overloads/ambiguities.jl create mode 100644 src/overloads/special_cases.jl rename src/overloads/{overload_all.jl => utils.jl} (100%) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 8a0e24bf..78f74d90 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -27,10 +27,12 @@ include("operators.jl") include("overloads/conversion.jl") include("overloads/gradient_tracer.jl") include("overloads/hessian_tracer.jl") +include("overloads/ambiguities.jl") +include("overloads/special_cases.jl") include("overloads/ifelse_global.jl") include("overloads/dual.jl") -include("overloads/overload_all.jl") include("overloads/arrays.jl") +include("overloads/utils.jl") include("interface.jl") include("adtypes.jl") diff --git a/src/overloads/ambiguities.jl b/src/overloads/ambiguities.jl new file mode 100644 index 00000000..cbf1f47c --- /dev/null +++ b/src/overloads/ambiguities.jl @@ -0,0 +1,35 @@ +## Special overloads to avoid ambiguity errors +for S in (Integer, Rational, Irrational{:ℯ}) + Base.:^(t::T, ::S) where {T<:GradientTracer} = t + Base.:^(::S, t::T) where {T<:GradientTracer} = t + Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) + Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) + + function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}} + x = primal(d) + t = gradient_tracer_1_to_1(tracer(d), false) + return Dual(x^y, t) + end + function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} + y = primal(d) + t = gradient_tracer_1_to_1(tracer(d), false) + return Dual(x^y, t) + end + + function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}} + x = primal(d) + t = hessian_tracer_1_to_1(tracer(d), false, false) + return Dual(x^y, t) + end + function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} + y = primal(d) + t = hessian_tracer_1_to_1(tracer(d), false, false) + return Dual(x^y, t) + end +end + +for TT in (GradientTracer, HessianTracer) + function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:TT,D<:Dual{P,T}} + return isless(primal(dx), y) + end +end diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index e7077588..1e255161 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -261,49 +261,3 @@ function overload_gradient_1_to_2(M::Symbol, f) return Expr(:block, expr_gradienttracer, expr_dual) end - -## Special overloads to avoid ambiguity errors - -for S in (Integer, Rational, Irrational{:ℯ}) - Base.:^(t::T, ::S) where {T<:GradientTracer} = t - Base.:^(::S, t::T) where {T<:GradientTracer} = t - function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}} - x = primal(d) - t = gradient_tracer_1_to_1(tracer(d), false) - return Dual(x^y, t) - end - function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}} - y = primal(d) - t = gradient_tracer_1_to_1(tracer(d), false) - return Dual(x^y, t) - end -end - -function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:GradientTracer,D<:Dual{P,T}} - return isless(primal(dx), y) -end - -## Rounding -Base.round(::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = myempty(T) -function Base.round( - d::D, mode::RoundingMode; kwargs... -) where {P,T<:GradientTracer,D<:Dual{P,T}} - return round(primal(d), mode; kwargs...) # only return primal -end - -for RR in (Real, Integer, Bool) - Base.round(::Type{R}, ::T) where {R<:RR,T<:GradientTracer} = myempty(T) - function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:GradientTracer,D<:Dual{P,T}} - return round(R, primal(d)) # only return primal - end -end - -## Random numbers -Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = myempty(T) -function Base.rand( - rng::AbstractRNG, ::SamplerType{D} -) where {P,T<:GradientTracer,D<:Dual{P,T}} - p = rand(rng, P) - t = myempty(T) - return Dual(p, t) -end diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index 6d4ed29c..59201cb3 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -369,50 +369,3 @@ function overload_hessian_1_to_2(M::Symbol, f) return Expr(:block, expr_hessiantracer, expr_dual) end - -## Special overloads to avoid ambiguity errors - -for S in (Integer, Rational, Irrational{:ℯ}) - Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) - Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false) - - function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}} - x = primal(d) - t = hessian_tracer_1_to_1(tracer(d), false, false) - return Dual(x^y, t) - end - function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}} - y = primal(d) - t = hessian_tracer_1_to_1(tracer(d), false, false) - return Dual(x^y, t) - end -end - -function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:HessianTracer,D<:Dual{P,T}} - return isless(primal(dx), y) -end - -## Rounding -Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = myempty(T) -function Base.round( - d::D, mode::RoundingMode; kwargs... -) where {P,T<:HessianTracer,D<:Dual{P,T}} - return round(primal(d), mode; kwargs...) # only return primal -end - -for RR in (Real, Integer, Bool) - Base.round(::Type{R}, ::T) where {R<:RR,T<:HessianTracer} = myempty(T) - function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:HessianTracer,D<:Dual{P,T}} - return round(R, primal(d)) # only return primal - end -end - -## Random numbers -Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = myempty(T) -function Base.rand( - rng::AbstractRNG, ::SamplerType{D} -) where {P,T<:HessianTracer,D<:Dual{P,T}} - p = rand(rng, P) - t = myempty(T) - return Dual(p, t) -end diff --git a/src/overloads/special_cases.jl b/src/overloads/special_cases.jl new file mode 100644 index 00000000..e116c47f --- /dev/null +++ b/src/overloads/special_cases.jl @@ -0,0 +1,25 @@ +## Rounding +Base.round(t::T, ::RoundingMode; kwargs...) where {T<:AbstractTracer} = myempty(T) + +function Base.round( + d::D, mode::RoundingMode; kwargs... +) where {P,T<:AbstractTracer,D<:Dual{P,T}} + return round(primal(d), mode; kwargs...) # only return primal +end + +for RR in (Real, Integer, Bool) + Base.round(::Type{R}, ::T) where {R<:RR,T<:AbstractTracer} = myempty(T) + function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:AbstractTracer,D<:Dual{P,T}} + return round(R, primal(d)) # only return primal + end +end + +## Random numbers +Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:AbstractTracer} = myempty(T) +function Base.rand( + rng::AbstractRNG, ::SamplerType{D} +) where {P,T<:AbstractTracer,D<:Dual{P,T}} + p = rand(rng, P) + t = myempty(T) + return Dual(p, t) +end diff --git a/src/overloads/overload_all.jl b/src/overloads/utils.jl similarity index 100% rename from src/overloads/overload_all.jl rename to src/overloads/utils.jl From e10fdf6dbcce05725127547a2a5a5370540e66fb Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 21 Aug 2024 14:58:21 +0200 Subject: [PATCH 2/5] Reorganize interface code --- src/SparseConnectivityTracer.jl | 4 ++-- src/{adtypes.jl => adtypes_interface.jl} | 7 +++++++ src/{interface.jl => trace_functions.jl} | 9 +++++---- 3 files changed, 14 insertions(+), 6 deletions(-) rename src/{adtypes.jl => adtypes_interface.jl} (94%) rename src/{interface.jl => trace_functions.jl} (96%) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 78f74d90..3747d85d 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -34,8 +34,8 @@ include("overloads/dual.jl") include("overloads/arrays.jl") include("overloads/utils.jl") -include("interface.jl") -include("adtypes.jl") +include("trace_functions.jl") +include("adtypes_interface.jl") export TracerSparsityDetector export TracerLocalSparsityDetector diff --git a/src/adtypes.jl b/src/adtypes_interface.jl similarity index 94% rename from src/adtypes.jl rename to src/adtypes_interface.jl index 3453c5c6..b2b685b5 100644 --- a/src/adtypes.jl +++ b/src/adtypes_interface.jl @@ -1,3 +1,10 @@ +#= This file implements the ADTypes interface for `AbstractSparsityDetector`s =# + +const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}} +const DEFAULT_HESSIAN_TRACER = HessianTracer{ + DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared} +} + """ TracerSparsityDetector <: ADTypes.AbstractSparsityDetector diff --git a/src/interface.jl b/src/trace_functions.jl similarity index 96% rename from src/interface.jl rename to src/trace_functions.jl index d815f78a..7d2b52d8 100644 --- a/src/interface.jl +++ b/src/trace_functions.jl @@ -1,7 +1,8 @@ -const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}} -const DEFAULT_HESSIAN_TRACER = HessianTracer{ - DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared} -} +#= This file handles the actual tracing of functions: +1) creating tracers from inputs +2) evaluating the function with the created tracers +3) parsing the resulting tracers into an output matrix +=# #==================# # Enumerate inputs # From abafbaa67592c514e470e65a440f4fc6f9a605bd Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 21 Aug 2024 15:20:59 +0200 Subject: [PATCH 3/5] Separate test file for NNlib extension --- test/ext/test_NNlib.jl | 140 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 +++ test/test_gradient.jl | 74 ---------------------- test/test_hessian.jl | 39 ------------ 4 files changed, 147 insertions(+), 113 deletions(-) create mode 100644 test/ext/test_NNlib.jl diff --git a/test/ext/test_NNlib.jl b/test/ext/test_NNlib.jl new file mode 100644 index 00000000..3664b36f --- /dev/null +++ b/test/ext/test_NNlib.jl @@ -0,0 +1,140 @@ +using SparseConnectivityTracer +using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input +using NNlib: NNlib +using Test + +NNLIB_ACTIVATIONS_S = ( + NNlib.σ, + NNlib.celu, + NNlib.elu, + NNlib.gelu, + NNlib.hardswish, + NNlib.lisht, + NNlib.logσ, + NNlib.logcosh, + NNlib.mish, + NNlib.selu, + NNlib.softplus, + NNlib.softsign, + NNlib.swish, + NNlib.sigmoid_fast, + NNlib.tanhshrink, + NNlib.tanh_fast, +) +NNLIB_ACTIVATIONS_F = ( + NNlib.hardσ, + NNlib.hardtanh, + NNlib.leakyrelu, + NNlib.relu, + NNlib.relu6, + NNlib.softshrink, + NNlib.trelu, +) +NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F) + +@testset "Jacobian Global" begin + method = TracerSparsityDetector() + J(f, x) = jacobian_sparsity(f, x, method) + + @testset "$f" for f in NNLIB_ACTIVATIONS + @test J(f, 1) ≈ [1;;] + end +end + +@testset "Jacobian Local" begin + method = TracerLocalSparsityDetector() + J(f, x) = jacobian_sparsity(f, x, method) + + @test J(NNlib.relu, -1) ≈ [0;;] + @test J(NNlib.relu, 1) ≈ [1;;] + @test J(NNlib.elu, -1) ≈ [1;;] + @test J(NNlib.elu, 1) ≈ [1;;] + @test J(NNlib.celu, -1) ≈ [1;;] + @test J(NNlib.celu, 1) ≈ [1;;] + @test J(NNlib.selu, -1) ≈ [1;;] + @test J(NNlib.selu, 1) ≈ [1;;] + + @test J(NNlib.relu6, -1) ≈ [0;;] + @test J(NNlib.relu6, 1) ≈ [1;;] + @test J(NNlib.relu6, 7) ≈ [0;;] + + @test J(NNlib.trelu, 0.9) ≈ [0;;] + @test J(NNlib.trelu, 1.1) ≈ [1;;] + + @test J(NNlib.swish, -5) ≈ [1;;] + @test J(NNlib.swish, 0) ≈ [1;;] + @test J(NNlib.swish, 5) ≈ [1;;] + + @test J(NNlib.hardswish, -5) ≈ [0;;] + @test J(NNlib.hardswish, 0) ≈ [1;;] + @test J(NNlib.hardswish, 5) ≈ [1;;] + + @test J(NNlib.hardσ, -4) ≈ [0;;] + @test J(NNlib.hardσ, 0) ≈ [1;;] + @test J(NNlib.hardσ, 4) ≈ [0;;] + + @test J(NNlib.hardtanh, -2) ≈ [0;;] + @test J(NNlib.hardtanh, 0) ≈ [1;;] + @test J(NNlib.hardtanh, 2) ≈ [0;;] + + @test J(NNlib.softshrink, -1) ≈ [1;;] + @test J(NNlib.softshrink, 0) ≈ [0;;] + @test J(NNlib.softshrink, 1) ≈ [1;;] +end + +@testset "Global Hessian" begin + method = TracerSparsityDetector() + H(f, x) = hessian_sparsity(f, x, method) + + @testset "First-order differentiable" begin + @testset "$f" for f in NNLIB_ACTIVATIONS_F + @test H(f, 1) ≈ [0;;] + end + end + @testset "Second-order differentiable" begin + @testset "$f" for f in NNLIB_ACTIVATIONS_S + @test H(f, 1) ≈ [1;;] + end + end +end + +@testset "Local Hessian" begin + method = TracerLocalSparsityDetector() + H(f, x) = hessian_sparsity(f, x, method) + + @test H(NNlib.relu, -1) ≈ [0;;] + @test H(NNlib.relu, 1) ≈ [0;;] + @test H(NNlib.elu, -1) ≈ [1;;] + @test H(NNlib.elu, 1) ≈ [0;;] + @test H(NNlib.celu, -1) ≈ [1;;] + @test H(NNlib.celu, 1) ≈ [0;;] + @test H(NNlib.selu, -1) ≈ [1;;] + @test H(NNlib.selu, 1) ≈ [0;;] + + @test H(NNlib.relu6, -1) ≈ [0;;] + @test H(NNlib.relu6, 1) ≈ [0;;] + @test H(NNlib.relu6, 7) ≈ [0;;] + + @test H(NNlib.trelu, 0.9) ≈ [0;;] + @test H(NNlib.trelu, 1.1) ≈ [0;;] + + @test H(NNlib.swish, -5) ≈ [1;;] + @test H(NNlib.swish, 0) ≈ [1;;] + @test H(NNlib.swish, 5) ≈ [1;;] + + @test H(NNlib.hardswish, -5) ≈ [0;;] + @test H(NNlib.hardswish, 0) ≈ [1;;] + @test H(NNlib.hardswish, 5) ≈ [0;;] + + @test H(NNlib.hardσ, -4) ≈ [0;;] + @test H(NNlib.hardσ, 0) ≈ [0;;] + @test H(NNlib.hardσ, 4) ≈ [0;;] + + @test H(NNlib.hardtanh, -2) ≈ [0;;] + @test H(NNlib.hardtanh, 0) ≈ [0;;] + @test H(NNlib.hardtanh, 2) ≈ [0;;] + + @test H(NNlib.softshrink, -1) ≈ [0;;] + @test H(NNlib.softshrink, 0) ≈ [0;;] + @test H(NNlib.softshrink, 1) ≈ [0;;] +end diff --git a/test/runtests.jl b/test/runtests.jl index 10dc3f02..b6158490 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -91,6 +91,13 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") end end end + if GROUP in ("Core", "All") + @info "Testing package extensions..." + @testset "NNlib" begin + @info "...NNlib" + include("ext/test_NNlib.jl") + end + end if GROUP in ("Core", "All") @info "Testing real-world examples..." diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 490be484..532760e5 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -6,40 +6,10 @@ using Compat: Returns using Random: rand, GLOBAL_RNG using LinearAlgebra: det, dot, logdet using SpecialFunctions: erf, beta -using NNlib: NNlib # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") -NNLIB_ACTIVATIONS_S = ( - NNlib.σ, - NNlib.celu, - NNlib.elu, - NNlib.gelu, - NNlib.hardswish, - NNlib.lisht, - NNlib.logσ, - NNlib.logcosh, - NNlib.mish, - NNlib.selu, - NNlib.softplus, - NNlib.softsign, - NNlib.swish, - NNlib.sigmoid_fast, - NNlib.tanhshrink, - NNlib.tanh_fast, -) -NNLIB_ACTIVATIONS_F = ( - NNlib.hardσ, - NNlib.hardtanh, - NNlib.leakyrelu, - NNlib.relu, - NNlib.relu6, - NNlib.softshrink, - NNlib.trelu, -) -NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F) - REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int}) # These exists to be able to quickly run tests in the REPL. @@ -133,12 +103,6 @@ J(f, x) = jacobian_sparsity(f, x, method) end end - @testset "NNlib" begin - @testset "$f" for f in NNLIB_ACTIVATIONS - @test J(f, 1) ≈ [1;;] - end - end - @testset "ifelse and comparisons" begin if VERSION >= v"1.8" @test J(x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4]) == @@ -274,44 +238,6 @@ end @test J(x -> log(det(x)), [1.0 -1.0; 2.0 2.0]) == [1 1 1 1] @test J(x -> dot(x[1:2], x[4:5]), [0, 1, 0, 1, 0]) == [1 0 0 0 1] end - - @testset "NNlib" begin - @test J(NNlib.relu, -1) ≈ [0;;] - @test J(NNlib.relu, 1) ≈ [1;;] - @test J(NNlib.elu, -1) ≈ [1;;] - @test J(NNlib.elu, 1) ≈ [1;;] - @test J(NNlib.celu, -1) ≈ [1;;] - @test J(NNlib.celu, 1) ≈ [1;;] - @test J(NNlib.selu, -1) ≈ [1;;] - @test J(NNlib.selu, 1) ≈ [1;;] - - @test J(NNlib.relu6, -1) ≈ [0;;] - @test J(NNlib.relu6, 1) ≈ [1;;] - @test J(NNlib.relu6, 7) ≈ [0;;] - - @test J(NNlib.trelu, 0.9) ≈ [0;;] - @test J(NNlib.trelu, 1.1) ≈ [1;;] - - @test J(NNlib.swish, -5) ≈ [1;;] - @test J(NNlib.swish, 0) ≈ [1;;] - @test J(NNlib.swish, 5) ≈ [1;;] - - @test J(NNlib.hardswish, -5) ≈ [0;;] - @test J(NNlib.hardswish, 0) ≈ [1;;] - @test J(NNlib.hardswish, 5) ≈ [1;;] - - @test J(NNlib.hardσ, -4) ≈ [0;;] - @test J(NNlib.hardσ, 0) ≈ [1;;] - @test J(NNlib.hardσ, 4) ≈ [0;;] - - @test J(NNlib.hardtanh, -2) ≈ [0;;] - @test J(NNlib.hardtanh, 0) ≈ [1;;] - @test J(NNlib.hardtanh, 2) ≈ [0;;] - - @test J(NNlib.softshrink, -1) ≈ [1;;] - @test J(NNlib.softshrink, 0) ≈ [0;;] - @test J(NNlib.softshrink, 1) ≈ [1;;] - end yield() end end diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 38a5fbc4..42a7ad2a 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -5,7 +5,6 @@ using Test using Random: rand, GLOBAL_RNG using SpecialFunctions: erf, beta -using NNlib: NNlib # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") @@ -398,44 +397,6 @@ end @test H(x -> rand(typeof(x)), 1) ≈ [0;;] @test H(x -> rand(GLOBAL_RNG, typeof(x)), 1) ≈ [0;;] end - - @testset "NNlib" begin - @test H(NNlib.relu, -1) ≈ [0;;] - @test H(NNlib.relu, 1) ≈ [0;;] - @test H(NNlib.elu, -1) ≈ [1;;] - @test H(NNlib.elu, 1) ≈ [0;;] - @test H(NNlib.celu, -1) ≈ [1;;] - @test H(NNlib.celu, 1) ≈ [0;;] - @test H(NNlib.selu, -1) ≈ [1;;] - @test H(NNlib.selu, 1) ≈ [0;;] - - @test H(NNlib.relu6, -1) ≈ [0;;] - @test H(NNlib.relu6, 1) ≈ [0;;] - @test H(NNlib.relu6, 7) ≈ [0;;] - - @test H(NNlib.trelu, 0.9) ≈ [0;;] - @test H(NNlib.trelu, 1.1) ≈ [0;;] - - @test H(NNlib.swish, -5) ≈ [1;;] - @test H(NNlib.swish, 0) ≈ [1;;] - @test H(NNlib.swish, 5) ≈ [1;;] - - @test H(NNlib.hardswish, -5) ≈ [0;;] - @test H(NNlib.hardswish, 0) ≈ [1;;] - @test H(NNlib.hardswish, 5) ≈ [0;;] - - @test H(NNlib.hardσ, -4) ≈ [0;;] - @test H(NNlib.hardσ, 0) ≈ [0;;] - @test H(NNlib.hardσ, 4) ≈ [0;;] - - @test H(NNlib.hardtanh, -2) ≈ [0;;] - @test H(NNlib.hardtanh, 0) ≈ [0;;] - @test H(NNlib.hardtanh, 2) ≈ [0;;] - - @test H(NNlib.softshrink, -1) ≈ [0;;] - @test H(NNlib.softshrink, 0) ≈ [0;;] - @test H(NNlib.softshrink, 1) ≈ [0;;] - end yield() end end From 28c4206d0e964beee73cd8625075ecea91c554ae Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 21 Aug 2024 15:21:17 +0200 Subject: [PATCH 4/5] Separate test file for SpecialFunctions extension --- test/ext/test_SpecialFunctions.jl | 43 +++++++++++++++++++++++++++++++ test/runtests.jl | 4 +++ test/test_gradient.jl | 6 ----- test/test_hessian.jl | 14 ---------- 4 files changed, 47 insertions(+), 20 deletions(-) create mode 100644 test/ext/test_SpecialFunctions.jl diff --git a/test/ext/test_SpecialFunctions.jl b/test/ext/test_SpecialFunctions.jl new file mode 100644 index 00000000..a92fa1d6 --- /dev/null +++ b/test/ext/test_SpecialFunctions.jl @@ -0,0 +1,43 @@ + +using SparseConnectivityTracer +using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input +using SpecialFunctions: erf, beta +using Test + +# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS +include("../tracers_definitions.jl") + +@testset "Jacobian Global" begin + method = TracerSparsityDetector() + J(f, x) = jacobian_sparsity(f, x, method) + + @test J(x -> erf(x[1]), rand(2)) == [1 0] + @test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0] +end + +# TODO: add tests +# @testset "Jacobian Local" begin +# method = TracerLocalSparsityDetector() +# J(f, x) = jacobian_sparsity(f, x, method) +# end + +@testset "Global Hessian" begin + method = TracerSparsityDetector() + H(f, x) = hessian_sparsity(f, x, method) + + @test H(x -> erf(x[1]), rand(2)) == [ + 1 0 + 0 0 + ] + @test H(x -> beta(x[1], x[2]), rand(3)) == [ + 1 1 0 + 1 1 0 + 0 0 0 + ] +end + +# TODO: add tests +# @testset "Local Hessian" begin +# method = TracerLocalSparsityDetector() +# H(f, x) = hessian_sparsity(f, x, method) +# end diff --git a/test/runtests.jl b/test/runtests.jl index b6158490..4606f7f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -97,6 +97,10 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") @info "...NNlib" include("ext/test_NNlib.jl") end + @testset "SpecialFunctions" begin + @info "...SpecialFunctions" + include("ext/test_SpecialFunctions.jl") + end end if GROUP in ("Core", "All") diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 532760e5..df0979e7 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -5,7 +5,6 @@ using Test using Compat: Returns using Random: rand, GLOBAL_RNG using LinearAlgebra: det, dot, logdet -using SpecialFunctions: erf, beta # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") @@ -80,11 +79,6 @@ J(f, x) = jacobian_sparsity(f, x, method) @test J(x -> dot(x[1:2], x[4:5]), rand(5)) == [1 1 0 1 1] end - @testset "SpecialFunctions extension" begin - @test J(x -> erf(x[1]), rand(2)) == [1 0] - @test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0] - end - @testset "MissingPrimalError" begin @testset "$f" for f in ( iseven, diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 42a7ad2a..66d633bf 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -2,9 +2,7 @@ using SparseConnectivityTracer using SparseConnectivityTracer: Dual, HessianTracer, MissingPrimalError using SparseConnectivityTracer: trace_input, create_tracers, pattern, shared using Test - using Random: rand, GLOBAL_RNG -using SpecialFunctions: erf, beta # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") @@ -261,18 +259,6 @@ D = Dual{Int,T} # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context @test_throws TypeError H(x -> x[1] > x[2] ? x[1]^x[2] : x[3] * x[4], rand(4)) end - - @testset "SpecialFunctions.jl" begin - @test H(x -> erf(x[1]), rand(2)) == [ - 1 0 - 0 0 - ] - @test H(x -> beta(x[1], x[2]), rand(3)) == [ - 1 1 0 - 1 1 0 - 0 0 0 - ] - end yield() end end From db6a873348c6228e5afdcb4ff41103a2cfb81db2 Mon Sep 17 00:00:00 2001 From: adrhill Date: Wed, 21 Aug 2024 15:26:42 +0200 Subject: [PATCH 5/5] Remove unused imports --- test/ext/test_NNlib.jl | 1 - test/ext/test_SpecialFunctions.jl | 1 - test/test_gradient.jl | 2 +- test/test_hessian.jl | 2 +- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/ext/test_NNlib.jl b/test/ext/test_NNlib.jl index 3664b36f..a349737f 100644 --- a/test/ext/test_NNlib.jl +++ b/test/ext/test_NNlib.jl @@ -1,5 +1,4 @@ using SparseConnectivityTracer -using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input using NNlib: NNlib using Test diff --git a/test/ext/test_SpecialFunctions.jl b/test/ext/test_SpecialFunctions.jl index a92fa1d6..3108e077 100644 --- a/test/ext/test_SpecialFunctions.jl +++ b/test/ext/test_SpecialFunctions.jl @@ -1,6 +1,5 @@ using SparseConnectivityTracer -using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input using SpecialFunctions: erf, beta using Test diff --git a/test/test_gradient.jl b/test/test_gradient.jl index df0979e7..f8230cc6 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -1,5 +1,5 @@ using SparseConnectivityTracer -using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input +using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError using Test using Compat: Returns diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 66d633bf..7f639db8 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -1,6 +1,6 @@ using SparseConnectivityTracer using SparseConnectivityTracer: Dual, HessianTracer, MissingPrimalError -using SparseConnectivityTracer: trace_input, create_tracers, pattern, shared +using SparseConnectivityTracer: create_tracers, pattern, shared using Test using Random: rand, GLOBAL_RNG