From 398bc442a2f00149b14b74f8c1ec6b21439d715d Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Wed, 31 Jul 2024 10:09:12 -0400 Subject: [PATCH] Add linear index support in get_struct --- src/DataLayouts/struct.jl | 84 ++++++++++++ src/DataLayouts/to_linear_index.jl | 39 ++++++ test/DataLayouts/unit_linear_indexing.jl | 159 +++++++++++++++++++++++ 3 files changed, 282 insertions(+) create mode 100644 src/DataLayouts/to_linear_index.jl create mode 100644 test/DataLayouts/unit_linear_indexing.jl diff --git a/src/DataLayouts/struct.jl b/src/DataLayouts/struct.jl index c20b580734..7e4c1b1fc3 100644 --- a/src/DataLayouts/struct.jl +++ b/src/DataLayouts/struct.jl @@ -218,6 +218,90 @@ Base.@propagate_inbounds function get_struct( @inbounds return array[start_index] end + +abstract type _Size end +struct DynamicSize <: _Size end +struct StaticSize{S_array, FD} <: _Size + function StaticSize{S, FD}() where {S, FD} + new{S::Tuple{Vararg{Int}}, FD}() + end +end + +Base.@pure StaticSize(s::Tuple{Vararg{Int}}, FD) = StaticSize{s, FD}() + +# Some @pure convenience functions for `StaticSize` +s_field_dim_1(::Type{StaticSize{S, FD}}) where {S, FD} = Tuple(map(j-> j == FD ? 1 : S[j], 1:length(S))) +s_field_dim_1(::StaticSize{S, FD}) where {S, FD} = Tuple(map(j-> j == FD ? 1 : S[j], 1:length(S))) + +Base.@pure get(::Type{StaticSize{S}}) where {S} = S +Base.@pure get(::StaticSize{S}) where {S} = S +Base.@pure Base.getindex(::StaticSize{S}, i::Int) where {S} = i <= length(S) ? S[i] : 1 +Base.@pure Base.ndims(::StaticSize{S}) where {S} = length(S) +Base.@pure Base.ndims(::Type{StaticSize{S}}) where {S} = length(S) +Base.@pure Base.length(::StaticSize{S}) where {S} = prod(S) + +Base.@propagate_inbounds cart_ind(n::NTuple, i::Integer) = + @inbounds CartesianIndices(map(x -> Base.OneTo(x), n))[i] +Base.@propagate_inbounds linear_ind(n::NTuple) = + @inbounds LinearIndices(map(x -> Base.OneTo(x), n)) + +include("to_linear_index.jl") # TODO: delete if not needed + +@inline function offset_index( + start_index::Integer, + ::Val{D}, + field_offset, + ss::StaticSize{SS}; +) where {D, SS} + # TODO: compute this offset directly without going through CartesianIndex + SS1 = s_field_dim_1(typeof(ss)) + ci = cart_ind(SS1, start_index) + ci_poff = CartesianIndex(ntuple(n -> n == D ? ci[n] + field_offset : ci[n], ndims(ss))) + return linear_ind(SS)[ci_poff] +end + +Base.@propagate_inbounds @generated function get_struct_linear( + array::AbstractArray{T}, + ::Type{S}, + ::Val{D}, + start_index::Integer, + ss::StaticSize; +) where {T, S, D} + tup = :(()) + for i in 1:fieldcount(S) + push!( + tup.args, + :(get_struct_linear( + array, + fieldtype(S, $i), + Val($D), + offset_index( + start_index, + Val($D), + $(fieldtypeoffset(T, S, Val(i))), + ss + ), + ss + )), + ) + end + return quote + Base.@_propagate_inbounds_meta + @inbounds bypass_constructor(S, $tup) + end +end + +# recursion base case: hit array type is the same as the struct leaf type +Base.@propagate_inbounds function get_struct_linear( + array::AbstractArray{S}, + ::Type{S}, + ::Val{D}, + start_index::Integer, + us::StaticSize +) where {S, D} + @inbounds return array[start_index] +end + """ set_struct!(array, val::S, Val(D), start_index) diff --git a/src/DataLayouts/to_linear_index.jl b/src/DataLayouts/to_linear_index.jl new file mode 100644 index 0000000000..7cc84eb8c6 --- /dev/null +++ b/src/DataLayouts/to_linear_index.jl @@ -0,0 +1,39 @@ +_to_linear_index(A::AbstractArray, li, ci) = _to_linear_index(A, Base.to_indices(li, (ci,))...) +_to_linear_index(A::AbstractArray, I::Integer...) = (@inline; _sub2ind(A, I...)) + +function _sub2ind(A::AbstractArray, I...) + @inline + _sub2ind(axes(A), I...) +end + +# 0-dimensional arrays and indexing with [] +_sub2ind(::Tuple{}) = 1 +_sub2ind(::Base.DimsInteger) = 1 +# _sub2ind(::Indices) = 1 +_sub2ind(::Tuple{}, I::Integer...) = (@inline; _sub2ind_recurse((), 1, 1, I...)) + +# Generic cases +_sub2ind(dims::Base.DimsInteger, I::Integer...) = (@inline; _sub2ind_recurse(dims, 1, 1, I...)) +_sub2ind(inds::Base.Indices, I::Integer...) = (@inline; _sub2ind_recurse(inds, 1, 1, I...)) +# In 1d, there's a question of whether we're doing cartesian indexing +# or linear indexing. Support only the former. +_sub2ind(inds::Base.Indices{1}, I::Integer...) = + throw(ArgumentError("Linear indexing is not defined for one-dimensional arrays")) +_sub2ind(inds::Tuple{Base.OneTo}, I::Integer...) = (@inline; _sub2ind_recurse(inds, 1, 1, I...)) # only OneTo is safe +_sub2ind(inds::Tuple{Base.OneTo}, i::Integer) = i + +_sub2ind_recurse(::Any, L, ind) = ind +function _sub2ind_recurse(::Tuple{}, L, ind, i::Integer, I::Integer...) + @inline + _sub2ind_recurse((), L, ind+(i-1)*L, I...) +end +function _sub2ind_recurse(inds, L, ind, i::Integer, I::Integer...) + @inline + r1 = inds[1] + _sub2ind_recurse(Base.tail(inds), nextL(L, r1), ind+offsetin(i, r1)*L, I...) +end + +nextL(L, l::Integer) = L*l +nextL(L, r::AbstractUnitRange) = L*length(r) +offsetin(i, l::Integer) = i-1 +offsetin(i, r::AbstractUnitRange) = i-first(r) diff --git a/test/DataLayouts/unit_linear_indexing.jl b/test/DataLayouts/unit_linear_indexing.jl new file mode 100644 index 0000000000..0c665c6501 --- /dev/null +++ b/test/DataLayouts/unit_linear_indexing.jl @@ -0,0 +1,159 @@ +#= +julia --check-bounds=yes --project +using Revise; include(joinpath("test", "DataLayouts", "unit_linear_indexing.jl")) +=# +using Test +using ClimaCore.DataLayouts +using ClimaCore.DataLayouts: get_struct_linear +import ClimaCore.Geometry +# import ClimaComms +using StaticArrays +# ClimaComms.@import_required_backends +import Random +Random.seed!(1234) + +offset_indices(::Type{FT}, + ::Type{S}, + ::Val{D}, + start_index::Integer, + ss::DataLayouts.StaticSize +) where {FT, S, D} = map(i-> DL.offset_index( + start_index, + Val(D), + DL.fieldtypeoffset(FT, S, Val(i)), + ss + ), 1:fieldcount(S)) +import ClimaCore.DataLayouts as DL +field_dim_to_one(s, dim) = Tuple(map(j-> j == dim ? 1 : s[j], 1:length(s))) + +Base.@propagate_inbounds cart_ind(n::NTuple, i::Integer) = + @inbounds CartesianIndices(map(x -> Base.OneTo(x), n))[i] +Base.@propagate_inbounds linear_ind(n::NTuple, ci::CartesianIndex) = + @inbounds LinearIndices(map(x -> Base.OneTo(x), n))[ci] +Base.@propagate_inbounds linear_ind(n::NTuple, loc::NTuple) = + linear_ind(n, CartesianIndex(loc)) + +function debug_get_struct_linear(args...; expect_test_throws=false) + if expect_test_throws + get_struct_linear(args...) + else + try + get_struct_linear(args...) + catch + get_struct_linear(args...) + end + end +end + +function one_to_n(a::Array) + for i in 1:length(a) + a[i] = i + end + return a +end +one_to_n(s::Tuple, ::Type{FT}) where {FT} = one_to_n(zeros(FT, s...)) +ncomponents(::Type{FT}, ::Type{S}) where {FT, S} = div(sizeof(S), sizeof(FT)) + +struct Foo{T} + x::T + y::T +end + +Base.zero(::Type{Foo{T}}) where {T} = Foo{T}(0, 0) + +@testset "get_struct - IFH indexing (float)" begin + FT = Float64 + S = FT + s_array = (3, 1, 4) + @test ncomponents(FT, S) == 1 + a = one_to_n(s_array, FT) + ss = DataLayouts.StaticSize(s_array, 2) + @test debug_get_struct_linear(a, S, Val(2), 1, ss) == 1.0 + @test debug_get_struct_linear(a, S, Val(2), 2, ss) == 2.0 + @test debug_get_struct_linear(a, S, Val(2), 3, ss) == 3.0 + @test debug_get_struct_linear(a, S, Val(2), 4, ss) == 4.0 + @test debug_get_struct_linear(a, S, Val(2), 5, ss) == 5.0 + @test debug_get_struct_linear(a, S, Val(2), 6, ss) == 6.0 + @test debug_get_struct_linear(a, S, Val(2), 7, ss) == 7.0 + @test debug_get_struct_linear(a, S, Val(2), 8, ss) == 8.0 + @test debug_get_struct_linear(a, S, Val(2), 9, ss) == 9.0 + @test debug_get_struct_linear(a, S, Val(2), 10, ss) == 10.0 + @test debug_get_struct_linear(a, S, Val(2), 11, ss) == 11.0 + @test debug_get_struct_linear(a, S, Val(2), 12, ss) == 12.0 + @test_throws BoundsError debug_get_struct_linear(a, S, Val(2), 13, ss; expect_test_throws=true) +end + +@testset "get_struct - IFH indexing" begin + FT = Float64 + S = Foo{FT} + s_array = (3, 2, 4) + @test ncomponents(FT, S) == 2 + a = one_to_n(s_array, FT) + ss = DataLayouts.StaticSize(s_array, 2) + @test debug_get_struct_linear(a, S, Val(2), 1, ss) == Foo{FT}(1.0, 4.0) + @test debug_get_struct_linear(a, S, Val(2), 2, ss) == Foo{FT}(2.0, 5.0) + @test debug_get_struct_linear(a, S, Val(2), 3, ss) == Foo{FT}(3.0, 6.0) + @test debug_get_struct_linear(a, S, Val(2), 4, ss) == Foo{FT}(7.0, 10.0) + @test debug_get_struct_linear(a, S, Val(2), 5, ss) == Foo{FT}(8.0, 11.0) + @test debug_get_struct_linear(a, S, Val(2), 6, ss) == Foo{FT}(9.0, 12.0) + @test debug_get_struct_linear(a, S, Val(2), 7, ss) == Foo{FT}(13.0, 16.0) + @test debug_get_struct_linear(a, S, Val(2), 8, ss) == Foo{FT}(14.0, 17.0) + @test debug_get_struct_linear(a, S, Val(2), 9, ss) == Foo{FT}(15.0, 18.0) + @test debug_get_struct_linear(a, S, Val(2), 10, ss) == Foo{FT}(19.0, 22.0) + @test debug_get_struct_linear(a, S, Val(2), 11, ss) == Foo{FT}(20.0, 23.0) + @test debug_get_struct_linear(a, S, Val(2), 12, ss) == Foo{FT}(21.0, 24.0) + @test_throws BoundsError debug_get_struct_linear(a, S, Val(2), 13, ss; expect_test_throws=true) +end + +@testset "get_struct - IJF indexing" begin + FT = Float64 + S = Foo{FT} + s_array = (3, 4, 2) + @test ncomponents(FT, S) == 2 + s = field_dim_to_one(s_array, 3) + a = one_to_n(s_array, FT) + ss = DataLayouts.StaticSize(s_array, 3) + @test debug_get_struct_linear(a, S, Val(3), 1, ss) == Foo{FT}(1.0, 13.0) + @test debug_get_struct_linear(a, S, Val(3), 2, ss) == Foo{FT}(2.0, 14.0) + @test debug_get_struct_linear(a, S, Val(3), 3, ss) == Foo{FT}(3.0, 15.0) + @test debug_get_struct_linear(a, S, Val(3), 4, ss) == Foo{FT}(4.0, 16.0) + @test debug_get_struct_linear(a, S, Val(3), 5, ss) == Foo{FT}(5.0, 17.0) + @test debug_get_struct_linear(a, S, Val(3), 6, ss) == Foo{FT}(6.0, 18.0) + @test debug_get_struct_linear(a, S, Val(3), 7, ss) == Foo{FT}(7.0, 19.0) + @test debug_get_struct_linear(a, S, Val(3), 8, ss) == Foo{FT}(8.0, 20.0) + @test debug_get_struct_linear(a, S, Val(3), 9, ss) == Foo{FT}(9.0, 21.0) + @test debug_get_struct_linear(a, S, Val(3), 10, ss) == Foo{FT}(10.0, 22.0) + @test debug_get_struct_linear(a, S, Val(3), 11, ss) == Foo{FT}(11.0, 23.0) + @test debug_get_struct_linear(a, S, Val(3), 12, ss) == Foo{FT}(12.0, 24.0) + @test_throws BoundsError debug_get_struct_linear(a, S, Val(3), 13, ss; expect_test_throws=true) +end + +@testset "get_struct - VIJFH indexing" begin + FT = Float64 + S = Foo{FT} + s_array = (2,2,2,2,2) + @test ncomponents(FT, S) == 2 + s = field_dim_to_one(s_array, 4) + a = one_to_n(s_array, FT) + ss = DataLayouts.StaticSize(s_array, 4) + + @test debug_get_struct_linear(a, S, Val(4), 1, ss) == Foo{FT}(1.0, 9.0) + @test debug_get_struct_linear(a, S, Val(4), 2, ss) == Foo{FT}(2.0, 10.0) + @test debug_get_struct_linear(a, S, Val(4), 3, ss) == Foo{FT}(3.0, 11.0) + @test debug_get_struct_linear(a, S, Val(4), 4, ss) == Foo{FT}(4.0, 12.0) + @test debug_get_struct_linear(a, S, Val(4), 5, ss) == Foo{FT}(5.0, 13.0) + @test debug_get_struct_linear(a, S, Val(4), 6, ss) == Foo{FT}(6.0, 14.0) + @test debug_get_struct_linear(a, S, Val(4), 7, ss) == Foo{FT}(7.0, 15.0) + @test debug_get_struct_linear(a, S, Val(4), 8, ss) == Foo{FT}(8.0, 16.0) + @test debug_get_struct_linear(a, S, Val(4), 9, ss) == Foo{FT}(17.0, 25.0) + @test debug_get_struct_linear(a, S, Val(4), 10, ss) == Foo{FT}(18.0, 26.0) + @test debug_get_struct_linear(a, S, Val(4), 11, ss) == Foo{FT}(19.0, 27.0) + @test debug_get_struct_linear(a, S, Val(4), 12, ss) == Foo{FT}(20.0, 28.0) + @test debug_get_struct_linear(a, S, Val(4), 13, ss) == Foo{FT}(21.0, 29.0) + @test debug_get_struct_linear(a, S, Val(4), 14, ss) == Foo{FT}(22.0, 30.0) + @test debug_get_struct_linear(a, S, Val(4), 15, ss) == Foo{FT}(23.0, 31.0) + @test debug_get_struct_linear(a, S, Val(4), 16, ss) == Foo{FT}(24.0, 32.0) + @test_throws BoundsError debug_get_struct_linear(a, S, Val(4), 17, ss; expect_test_throws=true) +end + +# # TODO: add set_struct!