Skip to content

Commit

Permalink
Move tests to test/ext folder
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Aug 22, 2024
1 parent 598fe2e commit 1c282d9
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 23 deletions.
72 changes: 72 additions & 0 deletions test/ext/test_DataInterpolations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@

using SparseConnectivityTracer
using DataInterpolations
using Test

# Sample of interpolation types
interpolations_z = (ConstantInterpolation,)
interpolations_f = (LinearInterpolation,)
interpolations_s = (QuadraticInterpolation,)

u = [1.0, 2.0, 5.0]
t = [0.0, 1.0, 3.0]

@testset "Jacobian Global" begin
method = TracerSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@testset "Non-differentiable" begin
@testset "$TI" for TI in interpolations_z
interpolant = TI(u, t)
@test J(interpolant, 2.0) [0;;]
end
end
@testset "First-order differentiable" begin
@testset "$TI" for TI in interpolations_f
interpolant = TI(u, t)
@test J(interpolant, 2.0) [1;;]
end
end
@testset "Second-order differentiable" begin
@testset "$TI" for TI in interpolations_s
interpolant = TI(u, t)
@test J(interpolant, 2.0) [1;;]
end
end
end

# TODO: add tests
# @testset "Jacobian Local" begin
# method = TracerLocalSparsityDetector()
# J(f, x) = jacobian_sparsity(f, x, method)
# end

@testset "Global Hessian" begin
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "Non-differentiable" begin
@testset "$TI" for TI in interpolations_z
interpolant = TI(u, t)
@test J(interpolant, 2.0) [0;;]
end
end
@testset "First-order differentiable" begin
@testset "$TI" for TI in interpolations_f
interpolant = TI(u, t)
@test J(interpolant, 2.0) [0;;]
end
end
@testset "Second-order differentiable" begin
@testset "$TI" for TI in interpolations_s
interpolant = TI(u, t)
@test J(interpolant, 2.0) [1;;]
end
end
end

# TODO: add tests
# @testset "Local Hessian" begin
# method = TracerLocalSparsityDetector()
# H(f, x) = hessian_sparsity(f, x, method)
# end
12 changes: 5 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,11 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core")
end
if GROUP in ("Core", "All")
@info "Testing package extensions..."
@testset "NNlib" begin
@info "...NNlib"
include("ext/test_NNlib.jl")
end
@testset "SpecialFunctions" begin
@info "...SpecialFunctions"
include("ext/test_SpecialFunctions.jl")
for ext in (:NNlib, :SpecialFunctions, :DataInterpolations)
@testset "$ext" begin
@info "...$ext"
include("ext/test_$ext.jl")
end
end
end

Expand Down
16 changes: 0 additions & 16 deletions test/test_gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ using LinearAlgebra: det, dot, logdet
# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("tracers_definitions.jl")

# Sample of interpolation types
interpolation_types = [ConstantInterpolation, LinearInterpolation, QuadraticInterpolation]

REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int})

# These exists to be able to quickly run tests in the REPL.
Expand Down Expand Up @@ -100,19 +97,6 @@ J(f, x) = jacobian_sparsity(f, x, method)
end
end

@testset "DataInterpolations" begin
u = [1.0, 2.0, 5.0]
t = [0.0, 1.0, 3.0]
@testset "$interpolation_type" for interpolation_type in interpolation_types
A = interpolation_type(u, t)
if interpolation_type == ConstantInterpolation
@test J(A, 2.0) == [0]
else
@test J(A, 2.0) == [1]
end
end
end

@testset "ifelse and comparisons" begin
if VERSION >= v"1.8"
@test J(x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4]) ==
Expand Down

0 comments on commit 1c282d9

Please sign in to comment.