diff --git a/Project.toml b/Project.toml index 72e5b14b..32339cbb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "StaticArrays" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.8.1" +version = "1.8.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -30,10 +30,11 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore"] +test = ["InteractiveUtils", "Test", "BenchmarkTools", "OffsetArrays", "Statistics", "Unitful", "Aqua", "ChainRulesTestUtils", "ChainRulesCore", "JLArrays"] diff --git a/ext/StaticArraysChainRulesCoreExt.jl b/ext/StaticArraysChainRulesCoreExt.jl index f4ff0748..5f7904aa 100644 --- a/ext/StaticArraysChainRulesCoreExt.jl +++ b/ext/StaticArraysChainRulesCoreExt.jl @@ -15,12 +15,12 @@ end # Project SArray to SArray function ProjectTo(x::SArray{S, T}) where {S, T} - return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = S) + return ProjectTo{SArray}(; element = CRC._eltype_projectto(T), axes = Size(x)) end -function (project::ProjectTo{SArray})(dx::AbstractArray{S, M}) where {S, M} - return SArray{project.axes}(dx) -end +@inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx) + +(project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.axes, dx) # Adjoint for SArray constructor function rrule(::Type{T}, x::Tuple) where {T <: SArray} diff --git a/test/chainrules.jl b/test/chainrules.jl index 2f01dc6b..7dbf8e8a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,4 +1,4 @@ -using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test +using StaticArrays, ChainRulesCore, ChainRulesTestUtils, JLArrays, Test @testset "Chain Rules Integration" begin @testset "Projection" begin @@ -9,4 +9,28 @@ using StaticArrays, ChainRulesCore, ChainRulesTestUtils, Test test_rrule(SVector{4}, 1.0, 1.0, 1.0, 1.0) test_rrule(SVector{4}, 1.0, 1.0f0, 1.0, 1.0f0) end + + @testset "Type Stability" begin + x = ones(SMatrix{2, 2}) + y = ones(SVector{4}) + + @inferred ProjectTo(x) + @inferred ProjectTo(y) + @inferred ProjectTo(x)(y) + @inferred ProjectTo(y)(x) + + x = ones(SMatrix{2, 2, Float32}) + y = ones(SVector{4}) + + @inferred ProjectTo(x) + @inferred ProjectTo(x)(y) + @inferred ProjectTo(y)(x) + end + + @testset "Array of Structs Projection" begin + x = JLArray(rand(SVector{3, Float64}, 10)) + @inferred ProjectTo(x) + @inferred Union{Nothing, JLVector{SVector{3, Float64}}, DenseJLVector{SVector{3, Float64}}} ProjectTo(x)(x) + @test ProjectTo(x)(x) isa JLArray + end end