Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add linear index support in get_struct #1919

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/DataLayouts/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,90 @@ Base.@propagate_inbounds function get_struct(
@inbounds return array[start_index]
end


abstract type _Size end
struct DynamicSize <: _Size end
struct StaticSize{S_array, FD} <: _Size
function StaticSize{S, FD}() where {S, FD}
new{S::Tuple{Vararg{Int}}, FD}()
end
end

Base.@pure StaticSize(s::Tuple{Vararg{Int}}, FD) = StaticSize{s, FD}()

# Some @pure convenience functions for `StaticSize`
s_field_dim_1(::Type{StaticSize{S, FD}}) where {S, FD} = Tuple(map(j-> j == FD ? 1 : S[j], 1:length(S)))
s_field_dim_1(::StaticSize{S, FD}) where {S, FD} = Tuple(map(j-> j == FD ? 1 : S[j], 1:length(S)))

Base.@pure get(::Type{StaticSize{S}}) where {S} = S
Base.@pure get(::StaticSize{S}) where {S} = S
Base.@pure Base.getindex(::StaticSize{S}, i::Int) where {S} = i <= length(S) ? S[i] : 1
Base.@pure Base.ndims(::StaticSize{S}) where {S} = length(S)
Base.@pure Base.ndims(::Type{StaticSize{S}}) where {S} = length(S)
Base.@pure Base.length(::StaticSize{S}) where {S} = prod(S)

Base.@propagate_inbounds cart_ind(n::NTuple, i::Integer) =
@inbounds CartesianIndices(map(x -> Base.OneTo(x), n))[i]
Base.@propagate_inbounds linear_ind(n::NTuple) =
@inbounds LinearIndices(map(x -> Base.OneTo(x), n))

include("to_linear_index.jl") # TODO: delete if not needed

@inline function offset_index(
start_index::Integer,
::Val{D},
field_offset,
ss::StaticSize{SS};
) where {D, SS}
# TODO: compute this offset directly without going through CartesianIndex
SS1 = s_field_dim_1(typeof(ss))
ci = cart_ind(SS1, start_index)
ci_poff = CartesianIndex(ntuple(n -> n == D ? ci[n] + field_offset : ci[n], ndims(ss)))
return linear_ind(SS)[ci_poff]
end

Base.@propagate_inbounds @generated function get_struct_linear(
array::AbstractArray{T},
::Type{S},
::Val{D},
start_index::Integer,
ss::StaticSize;
) where {T, S, D}
tup = :(())
for i in 1:fieldcount(S)
push!(
tup.args,
:(get_struct_linear(
array,
fieldtype(S, $i),
Val($D),
offset_index(
start_index,
Val($D),
$(fieldtypeoffset(T, S, Val(i))),
ss
),
ss
)),
)
end
return quote
Base.@_propagate_inbounds_meta
@inbounds bypass_constructor(S, $tup)
end
end

# recursion base case: hit array type is the same as the struct leaf type
Base.@propagate_inbounds function get_struct_linear(
array::AbstractArray{S},
::Type{S},
::Val{D},
start_index::Integer,
us::StaticSize
) where {S, D}
@inbounds return array[start_index]
end

"""
set_struct!(array, val::S, Val(D), start_index)

Expand Down
39 changes: 39 additions & 0 deletions src/DataLayouts/to_linear_index.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
_to_linear_index(A::AbstractArray, li, ci) = _to_linear_index(A, Base.to_indices(li, (ci,))...)
_to_linear_index(A::AbstractArray, I::Integer...) = (@inline; _sub2ind(A, I...))

function _sub2ind(A::AbstractArray, I...)
@inline
_sub2ind(axes(A), I...)
end

# 0-dimensional arrays and indexing with []
_sub2ind(::Tuple{}) = 1
_sub2ind(::Base.DimsInteger) = 1
# _sub2ind(::Indices) = 1
_sub2ind(::Tuple{}, I::Integer...) = (@inline; _sub2ind_recurse((), 1, 1, I...))

# Generic cases
_sub2ind(dims::Base.DimsInteger, I::Integer...) = (@inline; _sub2ind_recurse(dims, 1, 1, I...))
_sub2ind(inds::Base.Indices, I::Integer...) = (@inline; _sub2ind_recurse(inds, 1, 1, I...))
# In 1d, there's a question of whether we're doing cartesian indexing
# or linear indexing. Support only the former.
_sub2ind(inds::Base.Indices{1}, I::Integer...) =
throw(ArgumentError("Linear indexing is not defined for one-dimensional arrays"))
_sub2ind(inds::Tuple{Base.OneTo}, I::Integer...) = (@inline; _sub2ind_recurse(inds, 1, 1, I...)) # only OneTo is safe
_sub2ind(inds::Tuple{Base.OneTo}, i::Integer) = i

_sub2ind_recurse(::Any, L, ind) = ind
function _sub2ind_recurse(::Tuple{}, L, ind, i::Integer, I::Integer...)
@inline
_sub2ind_recurse((), L, ind+(i-1)*L, I...)
end
function _sub2ind_recurse(inds, L, ind, i::Integer, I::Integer...)
@inline
r1 = inds[1]
_sub2ind_recurse(Base.tail(inds), nextL(L, r1), ind+offsetin(i, r1)*L, I...)
end

nextL(L, l::Integer) = L*l
nextL(L, r::AbstractUnitRange) = L*length(r)
offsetin(i, l::Integer) = i-1
offsetin(i, r::AbstractUnitRange) = i-first(r)
159 changes: 159 additions & 0 deletions test/DataLayouts/unit_linear_indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#=
julia --check-bounds=yes --project
using Revise; include(joinpath("test", "DataLayouts", "unit_linear_indexing.jl"))
=#
using Test
using ClimaCore.DataLayouts
using ClimaCore.DataLayouts: get_struct_linear
import ClimaCore.Geometry
# import ClimaComms
using StaticArrays
# ClimaComms.@import_required_backends
import Random
Random.seed!(1234)

offset_indices(::Type{FT},
::Type{S},
::Val{D},
start_index::Integer,
ss::DataLayouts.StaticSize
) where {FT, S, D} = map(i-> DL.offset_index(
start_index,
Val(D),
DL.fieldtypeoffset(FT, S, Val(i)),
ss
), 1:fieldcount(S))
import ClimaCore.DataLayouts as DL
field_dim_to_one(s, dim) = Tuple(map(j-> j == dim ? 1 : s[j], 1:length(s)))

Base.@propagate_inbounds cart_ind(n::NTuple, i::Integer) =
@inbounds CartesianIndices(map(x -> Base.OneTo(x), n))[i]
Base.@propagate_inbounds linear_ind(n::NTuple, ci::CartesianIndex) =
@inbounds LinearIndices(map(x -> Base.OneTo(x), n))[ci]
Base.@propagate_inbounds linear_ind(n::NTuple, loc::NTuple) =
linear_ind(n, CartesianIndex(loc))

function debug_get_struct_linear(args...; expect_test_throws=false)
if expect_test_throws
get_struct_linear(args...)
else
try
get_struct_linear(args...)
catch
get_struct_linear(args...)
end
end
end

function one_to_n(a::Array)
for i in 1:length(a)
a[i] = i
end
return a
end
one_to_n(s::Tuple, ::Type{FT}) where {FT} = one_to_n(zeros(FT, s...))
ncomponents(::Type{FT}, ::Type{S}) where {FT, S} = div(sizeof(S), sizeof(FT))

struct Foo{T}
x::T
y::T
end

Base.zero(::Type{Foo{T}}) where {T} = Foo{T}(0, 0)

@testset "get_struct - IFH indexing (float)" begin
FT = Float64
S = FT
s_array = (3, 1, 4)
@test ncomponents(FT, S) == 1
a = one_to_n(s_array, FT)
ss = DataLayouts.StaticSize(s_array, 2)
@test debug_get_struct_linear(a, S, Val(2), 1, ss) == 1.0
@test debug_get_struct_linear(a, S, Val(2), 2, ss) == 2.0
@test debug_get_struct_linear(a, S, Val(2), 3, ss) == 3.0
@test debug_get_struct_linear(a, S, Val(2), 4, ss) == 4.0
@test debug_get_struct_linear(a, S, Val(2), 5, ss) == 5.0
@test debug_get_struct_linear(a, S, Val(2), 6, ss) == 6.0
@test debug_get_struct_linear(a, S, Val(2), 7, ss) == 7.0
@test debug_get_struct_linear(a, S, Val(2), 8, ss) == 8.0
@test debug_get_struct_linear(a, S, Val(2), 9, ss) == 9.0
@test debug_get_struct_linear(a, S, Val(2), 10, ss) == 10.0
@test debug_get_struct_linear(a, S, Val(2), 11, ss) == 11.0
@test debug_get_struct_linear(a, S, Val(2), 12, ss) == 12.0
@test_throws BoundsError debug_get_struct_linear(a, S, Val(2), 13, ss; expect_test_throws=true)
end

@testset "get_struct - IFH indexing" begin
FT = Float64
S = Foo{FT}
s_array = (3, 2, 4)
@test ncomponents(FT, S) == 2
a = one_to_n(s_array, FT)
ss = DataLayouts.StaticSize(s_array, 2)
@test debug_get_struct_linear(a, S, Val(2), 1, ss) == Foo{FT}(1.0, 4.0)
@test debug_get_struct_linear(a, S, Val(2), 2, ss) == Foo{FT}(2.0, 5.0)
@test debug_get_struct_linear(a, S, Val(2), 3, ss) == Foo{FT}(3.0, 6.0)
@test debug_get_struct_linear(a, S, Val(2), 4, ss) == Foo{FT}(7.0, 10.0)
@test debug_get_struct_linear(a, S, Val(2), 5, ss) == Foo{FT}(8.0, 11.0)
@test debug_get_struct_linear(a, S, Val(2), 6, ss) == Foo{FT}(9.0, 12.0)
@test debug_get_struct_linear(a, S, Val(2), 7, ss) == Foo{FT}(13.0, 16.0)
@test debug_get_struct_linear(a, S, Val(2), 8, ss) == Foo{FT}(14.0, 17.0)
@test debug_get_struct_linear(a, S, Val(2), 9, ss) == Foo{FT}(15.0, 18.0)
@test debug_get_struct_linear(a, S, Val(2), 10, ss) == Foo{FT}(19.0, 22.0)
@test debug_get_struct_linear(a, S, Val(2), 11, ss) == Foo{FT}(20.0, 23.0)
@test debug_get_struct_linear(a, S, Val(2), 12, ss) == Foo{FT}(21.0, 24.0)
@test_throws BoundsError debug_get_struct_linear(a, S, Val(2), 13, ss; expect_test_throws=true)
end

@testset "get_struct - IJF indexing" begin
FT = Float64
S = Foo{FT}
s_array = (3, 4, 2)
@test ncomponents(FT, S) == 2
s = field_dim_to_one(s_array, 3)
a = one_to_n(s_array, FT)
ss = DataLayouts.StaticSize(s_array, 3)
@test debug_get_struct_linear(a, S, Val(3), 1, ss) == Foo{FT}(1.0, 13.0)
@test debug_get_struct_linear(a, S, Val(3), 2, ss) == Foo{FT}(2.0, 14.0)
@test debug_get_struct_linear(a, S, Val(3), 3, ss) == Foo{FT}(3.0, 15.0)
@test debug_get_struct_linear(a, S, Val(3), 4, ss) == Foo{FT}(4.0, 16.0)
@test debug_get_struct_linear(a, S, Val(3), 5, ss) == Foo{FT}(5.0, 17.0)
@test debug_get_struct_linear(a, S, Val(3), 6, ss) == Foo{FT}(6.0, 18.0)
@test debug_get_struct_linear(a, S, Val(3), 7, ss) == Foo{FT}(7.0, 19.0)
@test debug_get_struct_linear(a, S, Val(3), 8, ss) == Foo{FT}(8.0, 20.0)
@test debug_get_struct_linear(a, S, Val(3), 9, ss) == Foo{FT}(9.0, 21.0)
@test debug_get_struct_linear(a, S, Val(3), 10, ss) == Foo{FT}(10.0, 22.0)
@test debug_get_struct_linear(a, S, Val(3), 11, ss) == Foo{FT}(11.0, 23.0)
@test debug_get_struct_linear(a, S, Val(3), 12, ss) == Foo{FT}(12.0, 24.0)
@test_throws BoundsError debug_get_struct_linear(a, S, Val(3), 13, ss; expect_test_throws=true)
end

@testset "get_struct - VIJFH indexing" begin
FT = Float64
S = Foo{FT}
s_array = (2,2,2,2,2)
@test ncomponents(FT, S) == 2
s = field_dim_to_one(s_array, 4)
a = one_to_n(s_array, FT)
ss = DataLayouts.StaticSize(s_array, 4)

@test debug_get_struct_linear(a, S, Val(4), 1, ss) == Foo{FT}(1.0, 9.0)
@test debug_get_struct_linear(a, S, Val(4), 2, ss) == Foo{FT}(2.0, 10.0)
@test debug_get_struct_linear(a, S, Val(4), 3, ss) == Foo{FT}(3.0, 11.0)
@test debug_get_struct_linear(a, S, Val(4), 4, ss) == Foo{FT}(4.0, 12.0)
@test debug_get_struct_linear(a, S, Val(4), 5, ss) == Foo{FT}(5.0, 13.0)
@test debug_get_struct_linear(a, S, Val(4), 6, ss) == Foo{FT}(6.0, 14.0)
@test debug_get_struct_linear(a, S, Val(4), 7, ss) == Foo{FT}(7.0, 15.0)
@test debug_get_struct_linear(a, S, Val(4), 8, ss) == Foo{FT}(8.0, 16.0)
@test debug_get_struct_linear(a, S, Val(4), 9, ss) == Foo{FT}(17.0, 25.0)
@test debug_get_struct_linear(a, S, Val(4), 10, ss) == Foo{FT}(18.0, 26.0)
@test debug_get_struct_linear(a, S, Val(4), 11, ss) == Foo{FT}(19.0, 27.0)
@test debug_get_struct_linear(a, S, Val(4), 12, ss) == Foo{FT}(20.0, 28.0)
@test debug_get_struct_linear(a, S, Val(4), 13, ss) == Foo{FT}(21.0, 29.0)
@test debug_get_struct_linear(a, S, Val(4), 14, ss) == Foo{FT}(22.0, 30.0)
@test debug_get_struct_linear(a, S, Val(4), 15, ss) == Foo{FT}(23.0, 31.0)
@test debug_get_struct_linear(a, S, Val(4), 16, ss) == Foo{FT}(24.0, 32.0)
@test_throws BoundsError debug_get_struct_linear(a, S, Val(4), 17, ss; expect_test_throws=true)
end

# # TODO: add set_struct!
Loading