From 1c282d90b9466c972bb36c5676c36104d43c395c Mon Sep 17 00:00:00 2001 From: adrhill Date: Thu, 22 Aug 2024 18:25:23 +0200 Subject: [PATCH] Move tests to `test/ext` folder --- test/ext/test_DataInterpolations.jl | 72 +++++++++++++++++++++++++++++ test/runtests.jl | 12 ++--- test/test_gradient.jl | 16 ------- 3 files changed, 77 insertions(+), 23 deletions(-) create mode 100644 test/ext/test_DataInterpolations.jl diff --git a/test/ext/test_DataInterpolations.jl b/test/ext/test_DataInterpolations.jl new file mode 100644 index 00000000..728bbebb --- /dev/null +++ b/test/ext/test_DataInterpolations.jl @@ -0,0 +1,72 @@ + +using SparseConnectivityTracer +using DataInterpolations +using Test + +# Sample of interpolation types +interpolations_z = (ConstantInterpolation,) +interpolations_f = (LinearInterpolation,) +interpolations_s = (QuadraticInterpolation,) + +u = [1.0, 2.0, 5.0] +t = [0.0, 1.0, 3.0] + +@testset "Jacobian Global" begin + method = TracerSparsityDetector() + J(f, x) = jacobian_sparsity(f, x, method) + + @testset "Non-differentiable" begin + @testset "$TI" for TI in interpolations_z + interpolant = TI(u, t) + @test J(interpolant, 2.0) ≈ [0;;] + end + end + @testset "First-order differentiable" begin + @testset "$TI" for TI in interpolations_f + interpolant = TI(u, t) + @test J(interpolant, 2.0) ≈ [1;;] + end + end + @testset "Second-order differentiable" begin + @testset "$TI" for TI in interpolations_s + interpolant = TI(u, t) + @test J(interpolant, 2.0) ≈ [1;;] + end + end +end + +# TODO: add tests +# @testset "Jacobian Local" begin +# method = TracerLocalSparsityDetector() +# J(f, x) = jacobian_sparsity(f, x, method) +# end + +@testset "Global Hessian" begin + method = TracerSparsityDetector() + H(f, x) = hessian_sparsity(f, x, method) + + @testset "Non-differentiable" begin + @testset "$TI" for TI in interpolations_z + interpolant = TI(u, t) + @test J(interpolant, 2.0) ≈ [0;;] + end + end + @testset "First-order differentiable" begin + @testset "$TI" for TI in interpolations_f + interpolant = TI(u, t) + @test J(interpolant, 2.0) ≈ [0;;] + end + end + @testset "Second-order differentiable" begin + @testset "$TI" for TI in interpolations_s + interpolant = TI(u, t) + @test J(interpolant, 2.0) ≈ [1;;] + end + end +end + +# TODO: add tests +# @testset "Local Hessian" begin +# method = TracerLocalSparsityDetector() +# H(f, x) = hessian_sparsity(f, x, method) +# end diff --git a/test/runtests.jl b/test/runtests.jl index 4606f7f4..b5415eed 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -93,13 +93,11 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") end if GROUP in ("Core", "All") @info "Testing package extensions..." - @testset "NNlib" begin - @info "...NNlib" - include("ext/test_NNlib.jl") - end - @testset "SpecialFunctions" begin - @info "...SpecialFunctions" - include("ext/test_SpecialFunctions.jl") + for ext in (:NNlib, :SpecialFunctions, :DataInterpolations) + @testset "$ext" begin + @info "...$ext" + include("ext/test_$ext.jl") + end end end diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 3f0acfdd..f8230cc6 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -9,9 +9,6 @@ using LinearAlgebra: det, dot, logdet # Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS include("tracers_definitions.jl") -# Sample of interpolation types -interpolation_types = [ConstantInterpolation, LinearInterpolation, QuadraticInterpolation] - REAL_TYPES = (Float64, Int, Bool, UInt8, Float16, Rational{Int}) # These exists to be able to quickly run tests in the REPL. @@ -100,19 +97,6 @@ J(f, x) = jacobian_sparsity(f, x, method) end end - @testset "DataInterpolations" begin - u = [1.0, 2.0, 5.0] - t = [0.0, 1.0, 3.0] - @testset "$interpolation_type" for interpolation_type in interpolation_types - A = interpolation_type(u, t) - if interpolation_type == ConstantInterpolation - @test J(A, 2.0) == [0] - else - @test J(A, 2.0) == [1] - end - end - end - @testset "ifelse and comparisons" begin if VERSION >= v"1.8" @test J(x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4]) ==