Skip to content

Commit

Permalink
Test output values and shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Aug 23, 2024
1 parent d3136fc commit 9d444ab
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion test/ext/test_DataInterpolations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using SparseConnectivityTracer
using SparseConnectivityTracer: DEFAULT_GRADIENT_TRACER, DEFAULT_HESSIAN_TRACER
using SparseConnectivityTracer: trace_input, Dual, primal
using DataInterpolations
using DataInterpolations: AbstractInterpolation
using Test
Expand Down Expand Up @@ -77,6 +79,33 @@ function test_interpolation(t::InterpolationTest{N}) where {N} # N ≠ 1
end
end

myprimal(x) = x
myprimal(d::Dual) = primal(d)

function test_output(t::InterpolationTest)
@testset "Output sizes and values" begin
out_ref = t.interp(tquery)
s_ref = size(out_ref)

@testset "$T" for T in (DEFAULT_GRADIENT_TRACER, DEFAULT_HESSIAN_TRACER)
t_tracer = trace_input(T, tquery)
out_tracer = t.interp(t_tracer)
s_tracer = size(out_tracer)
@test s_tracer == s_ref
end
@testset "$T" for T in (
Dual{typeof(tquery),DEFAULT_GRADIENT_TRACER},
Dual{typeof(tquery),DEFAULT_HESSIAN_TRACER},
)
t_dual = trace_input(T, tquery)
out_dual = t.interp(t_dual)
s_dual = size(out_dual)
@test s_dual == s_ref
@test myprimal.(out_dual) out_ref
end
end
end

#===========#
# Run tests #
#===========#
Expand All @@ -93,9 +122,9 @@ interpolation_tests = (
InterpolationTest(1, CubicSpline(uv, t)),
)


@testset "Test interpolations" begin
@testset "$(name(t))" for t in interpolation_tests
test_interpolation(t)
test_output(t)
end
end

0 comments on commit 9d444ab

Please sign in to comment.