From 542bbf8706892aefb9d956334d24911f18b68e23 Mon Sep 17 00:00:00 2001 From: jishnub Date: Sun, 30 May 2021 23:10:00 +0400 Subject: [PATCH] specialize axes(S, dim) to return SOneTo --- src/abstractarray.jl | 1 + test/abstractarray.jl | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 961e87a3..4bcec3db 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -12,6 +12,7 @@ Base.axes(s::StaticArray) = _axes(Size(s)) @pure function _axes(::Size{sizes}) where {sizes} map(SOneTo, sizes) end +Base.axes(s::StaticArray, d) = d <= ndims(s) ? _axes(Size(s))[d] : SOneTo{1}() Base.axes(rv::Adjoint{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...) Base.axes(rv::Transpose{<:Any,<:StaticVector}) = (SOneTo(1), axes(rv.parent)...) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 5800a300..b94ad2c9 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -134,6 +134,10 @@ using StaticArrays, Test, LinearAlgebra msr = reshape(ms, SOneTo(4)) msr[2] = 10 @test ms == SA[1 2; 10 4] + + s = SA[1,2]; + s2 = @inferred reshape(s, axes(s,1), axes(s,2)) + @test s2 isa StaticArray end @testset "copy" begin @@ -200,7 +204,7 @@ using StaticArrays, Test, LinearAlgebra @test @inferred(convert(AbstractArray{Float64}, diag)) isa Diagonal{Float64,SVector{2,Float64}} @test convert(AbstractArray{Float64}, diag) == diag # The following cases currently convert the SMatrix into an MMatrix, because - # the constructor in Base invokes `similar`, rather than `convert`, on the static + # the constructor in Base invokes `similar`, rather than `convert`, on the static # array. This was fixed in https://github.com/JuliaLang/julia/pull/40831; so should # work from Julia v1.8.0-DEV.55 trans = Transpose(SVector(1,2)) @@ -297,7 +301,7 @@ end @test Base.rest(x) == x a, b... = x @test b == SA[2, 3] - + x = SA[1 2; 3 4] @test Base.rest(x) == vec(x) a, b... = x @@ -306,14 +310,14 @@ end a, b... = SA[1] @test b == [] @test b isa SVector{0} - + for (Vec, Mat) in [(MVector, MMatrix), (SizedVector, SizedMatrix)] x = Vec(1, 2, 3) @test Base.rest(x) == x @test pointer(Base.rest(x)) != pointer(x) a, b... = x @test b == Vec(2, 3) - + x = Mat{2,2}(1, 2, 3, 4) @test Base.rest(x) == vec(x) @test pointer(Base.rest(x)) != pointer(x)