Skip to content

Commit

Permalink
Reorganize code (#179)
Browse files Browse the repository at this point in the history
* Separate test file for NNlib extension

* Separate test file for SpecialFunctions extension

* Split manual overloads into separate files

* Rename source files with outdated names
  • Loading branch information
adrhill authored Aug 21, 2024
1 parent 8f2a318 commit f0d36c6
Show file tree
Hide file tree
Showing 13 changed files with 271 additions and 235 deletions.
8 changes: 5 additions & 3 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ include("operators.jl")
include("overloads/conversion.jl")
include("overloads/gradient_tracer.jl")
include("overloads/hessian_tracer.jl")
include("overloads/ambiguities.jl")
include("overloads/special_cases.jl")
include("overloads/ifelse_global.jl")
include("overloads/dual.jl")
include("overloads/overload_all.jl")
include("overloads/arrays.jl")
include("overloads/utils.jl")

include("interface.jl")
include("adtypes.jl")
include("trace_functions.jl")
include("adtypes_interface.jl")

export TracerSparsityDetector
export TracerLocalSparsityDetector
Expand Down
7 changes: 7 additions & 0 deletions src/adtypes.jl → src/adtypes_interface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
#= This file implements the ADTypes interface for `AbstractSparsityDetector`s =#

const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared}
}

"""
TracerSparsityDetector <: ADTypes.AbstractSparsityDetector
Expand Down
35 changes: 35 additions & 0 deletions src/overloads/ambiguities.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
## Special overloads to avoid ambiguity errors
for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:GradientTracer} = t
Base.:^(::S, t::T) where {T<:GradientTracer} = t
Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)
Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)

function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}}
x = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}}
y = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end

function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}}
x = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}}
y = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
end

for TT in (GradientTracer, HessianTracer)
function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:TT,D<:Dual{P,T}}
return isless(primal(dx), y)
end
end
46 changes: 0 additions & 46 deletions src/overloads/gradient_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,49 +261,3 @@ function overload_gradient_1_to_2(M::Symbol, f)

return Expr(:block, expr_gradienttracer, expr_dual)
end

## Special overloads to avoid ambiguity errors

for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:GradientTracer} = t
Base.:^(::S, t::T) where {T<:GradientTracer} = t
function Base.:^(d::D, y::S) where {P,T<:GradientTracer,D<:Dual{P,T}}
x = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:GradientTracer,D<:Dual{P,T}}
y = primal(d)
t = gradient_tracer_1_to_1(tracer(d), false)
return Dual(x^y, t)
end
end

function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:GradientTracer,D<:Dual{P,T}}
return isless(primal(dx), y)
end

## Rounding
Base.round(::T, ::RoundingMode; kwargs...) where {T<:GradientTracer} = myempty(T)
function Base.round(
d::D, mode::RoundingMode; kwargs...
) where {P,T<:GradientTracer,D<:Dual{P,T}}
return round(primal(d), mode; kwargs...) # only return primal
end

for RR in (Real, Integer, Bool)
Base.round(::Type{R}, ::T) where {R<:RR,T<:GradientTracer} = myempty(T)
function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:GradientTracer,D<:Dual{P,T}}
return round(R, primal(d)) # only return primal
end
end

## Random numbers
Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:GradientTracer} = myempty(T)
function Base.rand(
rng::AbstractRNG, ::SamplerType{D}
) where {P,T<:GradientTracer,D<:Dual{P,T}}
p = rand(rng, P)
t = myempty(T)
return Dual(p, t)
end
47 changes: 0 additions & 47 deletions src/overloads/hessian_tracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,50 +369,3 @@ function overload_hessian_1_to_2(M::Symbol, f)

return Expr(:block, expr_hessiantracer, expr_dual)
end

## Special overloads to avoid ambiguity errors

for S in (Integer, Rational, Irrational{:ℯ})
Base.:^(t::T, ::S) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)
Base.:^(::S, t::T) where {T<:HessianTracer} = hessian_tracer_1_to_1(t, false, false)

function Base.:^(d::D, y::S) where {P,T<:HessianTracer,D<:Dual{P,T}}
x = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
function Base.:^(x::S, d::D) where {P,T<:HessianTracer,D<:Dual{P,T}}
y = primal(d)
t = hessian_tracer_1_to_1(tracer(d), false, false)
return Dual(x^y, t)
end
end

function Base.isless(dx::D, y::AbstractFloat) where {P<:Real,T<:HessianTracer,D<:Dual{P,T}}
return isless(primal(dx), y)
end

## Rounding
Base.round(t::T, ::RoundingMode; kwargs...) where {T<:HessianTracer} = myempty(T)
function Base.round(
d::D, mode::RoundingMode; kwargs...
) where {P,T<:HessianTracer,D<:Dual{P,T}}
return round(primal(d), mode; kwargs...) # only return primal
end

for RR in (Real, Integer, Bool)
Base.round(::Type{R}, ::T) where {R<:RR,T<:HessianTracer} = myempty(T)
function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:HessianTracer,D<:Dual{P,T}}
return round(R, primal(d)) # only return primal
end
end

## Random numbers
Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:HessianTracer} = myempty(T)
function Base.rand(
rng::AbstractRNG, ::SamplerType{D}
) where {P,T<:HessianTracer,D<:Dual{P,T}}
p = rand(rng, P)
t = myempty(T)
return Dual(p, t)
end
25 changes: 25 additions & 0 deletions src/overloads/special_cases.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
## Rounding
Base.round(t::T, ::RoundingMode; kwargs...) where {T<:AbstractTracer} = myempty(T)

function Base.round(
d::D, mode::RoundingMode; kwargs...
) where {P,T<:AbstractTracer,D<:Dual{P,T}}
return round(primal(d), mode; kwargs...) # only return primal
end

for RR in (Real, Integer, Bool)
Base.round(::Type{R}, ::T) where {R<:RR,T<:AbstractTracer} = myempty(T)
function Base.round(::Type{R}, d::D) where {R<:RR,P,T<:AbstractTracer,D<:Dual{P,T}}
return round(R, primal(d)) # only return primal
end
end

## Random numbers
Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:AbstractTracer} = myempty(T)
function Base.rand(
rng::AbstractRNG, ::SamplerType{D}
) where {P,T<:AbstractTracer,D<:Dual{P,T}}
p = rand(rng, P)
t = myempty(T)
return Dual(p, t)
end
File renamed without changes.
9 changes: 5 additions & 4 deletions src/interface.jl → src/trace_functions.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}}
const DEFAULT_HESSIAN_TRACER = HessianTracer{
DictHessianPattern{Int,BitSet,Dict{Int,BitSet},NotShared}
}
#= This file handles the actual tracing of functions:
1) creating tracers from inputs
2) evaluating the function with the created tracers
3) parsing the resulting tracers into an output matrix
=#

#==================#
# Enumerate inputs #
Expand Down
139 changes: 139 additions & 0 deletions test/ext/test_NNlib.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using SparseConnectivityTracer
using NNlib: NNlib
using Test

NNLIB_ACTIVATIONS_S = (
NNlib.σ,
NNlib.celu,
NNlib.elu,
NNlib.gelu,
NNlib.hardswish,
NNlib.lisht,
NNlib.logσ,
NNlib.logcosh,
NNlib.mish,
NNlib.selu,
NNlib.softplus,
NNlib.softsign,
NNlib.swish,
NNlib.sigmoid_fast,
NNlib.tanhshrink,
NNlib.tanh_fast,
)
NNLIB_ACTIVATIONS_F = (
NNlib.hardσ,
NNlib.hardtanh,
NNlib.leakyrelu,
NNlib.relu,
NNlib.relu6,
NNlib.softshrink,
NNlib.trelu,
)
NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F)

@testset "Jacobian Global" begin
method = TracerSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@testset "$f" for f in NNLIB_ACTIVATIONS
@test J(f, 1) [1;;]
end
end

@testset "Jacobian Local" begin
method = TracerLocalSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@test J(NNlib.relu, -1) [0;;]
@test J(NNlib.relu, 1) [1;;]
@test J(NNlib.elu, -1) [1;;]
@test J(NNlib.elu, 1) [1;;]
@test J(NNlib.celu, -1) [1;;]
@test J(NNlib.celu, 1) [1;;]
@test J(NNlib.selu, -1) [1;;]
@test J(NNlib.selu, 1) [1;;]

@test J(NNlib.relu6, -1) [0;;]
@test J(NNlib.relu6, 1) [1;;]
@test J(NNlib.relu6, 7) [0;;]

@test J(NNlib.trelu, 0.9) [0;;]
@test J(NNlib.trelu, 1.1) [1;;]

@test J(NNlib.swish, -5) [1;;]
@test J(NNlib.swish, 0) [1;;]
@test J(NNlib.swish, 5) [1;;]

@test J(NNlib.hardswish, -5) [0;;]
@test J(NNlib.hardswish, 0) [1;;]
@test J(NNlib.hardswish, 5) [1;;]

@test J(NNlib.hardσ, -4) [0;;]
@test J(NNlib.hardσ, 0) [1;;]
@test J(NNlib.hardσ, 4) [0;;]

@test J(NNlib.hardtanh, -2) [0;;]
@test J(NNlib.hardtanh, 0) [1;;]
@test J(NNlib.hardtanh, 2) [0;;]

@test J(NNlib.softshrink, -1) [1;;]
@test J(NNlib.softshrink, 0) [0;;]
@test J(NNlib.softshrink, 1) [1;;]
end

@testset "Global Hessian" begin
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "First-order differentiable" begin
@testset "$f" for f in NNLIB_ACTIVATIONS_F
@test H(f, 1) [0;;]
end
end
@testset "Second-order differentiable" begin
@testset "$f" for f in NNLIB_ACTIVATIONS_S
@test H(f, 1) [1;;]
end
end
end

@testset "Local Hessian" begin
method = TracerLocalSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@test H(NNlib.relu, -1) [0;;]
@test H(NNlib.relu, 1) [0;;]
@test H(NNlib.elu, -1) [1;;]
@test H(NNlib.elu, 1) [0;;]
@test H(NNlib.celu, -1) [1;;]
@test H(NNlib.celu, 1) [0;;]
@test H(NNlib.selu, -1) [1;;]
@test H(NNlib.selu, 1) [0;;]

@test H(NNlib.relu6, -1) [0;;]
@test H(NNlib.relu6, 1) [0;;]
@test H(NNlib.relu6, 7) [0;;]

@test H(NNlib.trelu, 0.9) [0;;]
@test H(NNlib.trelu, 1.1) [0;;]

@test H(NNlib.swish, -5) [1;;]
@test H(NNlib.swish, 0) [1;;]
@test H(NNlib.swish, 5) [1;;]

@test H(NNlib.hardswish, -5) [0;;]
@test H(NNlib.hardswish, 0) [1;;]
@test H(NNlib.hardswish, 5) [0;;]

@test H(NNlib.hardσ, -4) [0;;]
@test H(NNlib.hardσ, 0) [0;;]
@test H(NNlib.hardσ, 4) [0;;]

@test H(NNlib.hardtanh, -2) [0;;]
@test H(NNlib.hardtanh, 0) [0;;]
@test H(NNlib.hardtanh, 2) [0;;]

@test H(NNlib.softshrink, -1) [0;;]
@test H(NNlib.softshrink, 0) [0;;]
@test H(NNlib.softshrink, 1) [0;;]
end
42 changes: 42 additions & 0 deletions test/ext/test_SpecialFunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@

using SparseConnectivityTracer
using SpecialFunctions: erf, beta
using Test

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("../tracers_definitions.jl")

@testset "Jacobian Global" begin
method = TracerSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@test J(x -> erf(x[1]), rand(2)) == [1 0]
@test J(x -> beta(x[1], x[2]), rand(3)) == [1 1 0]
end

# TODO: add tests
# @testset "Jacobian Local" begin
# method = TracerLocalSparsityDetector()
# J(f, x) = jacobian_sparsity(f, x, method)
# end

@testset "Global Hessian" begin
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@test H(x -> erf(x[1]), rand(2)) == [
1 0
0 0
]
@test H(x -> beta(x[1], x[2]), rand(3)) == [
1 1 0
1 1 0
0 0 0
]
end

# TODO: add tests
# @testset "Local Hessian" begin
# method = TracerLocalSparsityDetector()
# H(f, x) = hessian_sparsity(f, x, method)
# end
Loading

0 comments on commit f0d36c6

Please sign in to comment.