diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 00000000..cc91afe1 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1,6 @@ +style = "blue" +align_assignment = true +align_struct_field = true +align_conditional = true +align_pair_arrow = true +align_matrix = true diff --git a/docs/make.jl b/docs/make.jl index 096d8021..b21b6e94 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,23 +1,23 @@ using SparseConnectivityTracer using Documenter -DocMeta.setdocmeta!(SparseConnectivityTracer, :DocTestSetup, :(using SparseConnectivityTracer); recursive=true) +DocMeta.setdocmeta!( + SparseConnectivityTracer, + :DocTestSetup, + :(using SparseConnectivityTracer); + recursive=true, +) makedocs(; modules=[SparseConnectivityTracer], authors="Adrian Hill ", sitename="SparseConnectivityTracer.jl", format=Documenter.HTML(; - canonical="https://adrhill.github.io/SparseConnectivityTracer.jl", - edit_link="main", - assets=String[], + canonical = "https://adrhill.github.io/SparseConnectivityTracer.jl", + edit_link = "main", + assets = String[], ), - pages=[ - "Home" => "index.md", - ], + pages=["Home" => "index.md"], ) -deploydocs(; - repo="github.com/adrhill/SparseConnectivityTracer.jl", - devbranch="main", -) +deploydocs(; repo="github.com/adrhill/SparseConnectivityTracer.jl", devbranch="main") diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 02041487..5685d42a 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,5 +1,68 @@ module SparseConnectivityTracer -# Write your package code here. +# Input connectivity tracer +struct Tracer <: Number + inputs::Set{UInt64} # indices of connected, enumerated inputs +end + +Tracer(i::Integer) = Tracer(Set{UInt64}(i)) +Tracer(a::Tracer, b::Tracer) = Tracer(a.inputs ∪ b.inputs) + +# Enumerate inputs +inputtrace(x) = inputtrace(x, 1) +inputtrace(::Number, i) = Tracer(i) +function inputtrace(x::AbstractArray, i) + indices = (i - 1) .+ reshape(1:length(x), size(x)) + return Tracer.(indices) +end + +include("ops.jl") + +# Extent core operators +for fn in (:+, :-, :*, :/, :^) + @eval Base.$fn(a::Tracer, b::Tracer) = Tracer(a, b) + for T in (:Number,) + @eval Base.$fn(t::Tracer, ::$T) = t + @eval Base.$fn(::$T, t::Tracer) = t + end +end + +Base.:^(a::Tracer, b::Tracer) = Tracer(a, b) +for T in (:Number, :Integer, :Rational) + @eval Base.:^(t::Tracer, ::$T) = t + @eval Base.:^(::$T, t::Tracer) = t +end +Base.:^(t::Tracer, ::Irrational{:ℯ}) = t +Base.:^(::Irrational{:ℯ}, t::Tracer) = t + +# Two-argument functions +for fn in (:div, :fld, :cld) + @eval Base.$fn(a::Tracer, b::Tracer) = Tracer(a, b) + @eval Base.$fn(t::Tracer, ::Number) = t + @eval Base.$fn(::Number, t::Tracer) = t +end + +# Single-argument functions +for fn in scalar_operations + @eval Base.$fn(t::Tracer) = t +end + +function connectivity(f, x) + xt = inputtrace(x) + yt = f(xt) + n, m = length(xt), length(yt) + + # Construct connectivity matrix of size (ouput_dim, input_dim) + C = BitArray(undef, m, n) + for i in axes(C, 1) + tracer = yt[i] + for j in axes(C, 2) + C[i, j] = j ∈ tracer.inputs + end + end + return C +end + +export connectivity end diff --git a/src/enumerate.jl b/src/enumerate.jl new file mode 100644 index 00000000..763c0cda --- /dev/null +++ b/src/enumerate.jl @@ -0,0 +1,7 @@ + +enumerate_tracers(x) = enumerate_tracers(x, 1) +enumerate_tracers(::Number, i) = Tracer(i) +function enumerate_tracers(x::AbstractArray, i) + indices = (i - 1) .+ reshape(1:length(x), size(x)) + return Tracer.(indices) +end diff --git a/src/ops.jl b/src/ops.jl new file mode 100644 index 00000000..87924afa --- /dev/null +++ b/src/ops.jl @@ -0,0 +1,20 @@ +#! format: off +scalar_operations = ( + :exp2, :deg2rad, :rad2deg, + :sincos, :sincospi, + :cos, :cosd, :cosh, :cospi, :cosc, + :sin, :sind, :sinh, :sinpi, :sinc, + :tan, :tand, :tanh, + :csc, :cscd, :csch, + :sec, :secd, :sech, + :cot, :cotd, :coth, + :acos, :acosd, :acosh, + :asin, :asind, :asinh, + :atan, :atand, :atanh, + :asec, :asech, + :acsc, :acsch, + :acot, :acoth, + :exp, :expm1, :exp10, + :frexp, :ldexp, +) +#! format: on diff --git a/test/Project.toml b/test/Project.toml index 737dfdaf..75ff838e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index 9eff7a85..3b50e4cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,29 @@ using SparseConnectivityTracer using Test +using JuliaFormatter using Aqua using JET +using LinearAlgebra +using Random + @testset "SparseConnectivityTracer.jl" begin - @testset "Code quality (Aqua.jl)" begin + @testset "Code formatting" begin + @test JuliaFormatter.format( + SparseConnectivityTracer; verbose=false, overwrite=false + ) + end + @testset "Aqua.jl tests" begin Aqua.test_all(SparseConnectivityTracer) end - @testset "Code linting (JET.jl)" begin - JET.test_package(SparseConnectivityTracer; target_defined_modules = true) + @testset "JET tests" begin + JET.test_package(SparseConnectivityTracer; target_defined_modules=true) + end + + @testset "Connectivity" begin + f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] + @test connectivity(f, rand(3)) == BitMatrix([1 0 0; 1 1 0; 0 0 1]) + + @test connectivity(identity, rand()) == BitMatrix([1;;]) end - # Write your tests here. end