Skip to content

Commit b4a2ed5

Browse files
wip
1 parent cacbb5c commit b4a2ed5

File tree

9 files changed

+146
-117
lines changed

9 files changed

+146
-117
lines changed

ext/cuda/data_layouts.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@ import CUDA
1313
parent_array_type(::Type{<:CUDA.CuArray{T, N, B} where {N}}) where {T, B} =
1414
CUDA.CuArray{T, N, B} where {N}
1515

16+
# Can we remove this?
17+
# parent_array_type(
18+
# ::Type{<:CUDA.CuArray{T, N, B} where {N}},
19+
# ::Val{ND},
20+
# ) where {T, B, ND} = CUDA.CuArray{T, ND, B}
21+
1622
parent_array_type(
1723
::Type{<:CUDA.CuArray{T, N, B} where {N}},
18-
::Val{ND},
19-
) where {T, B, ND} = CUDA.CuArray{T, ND, B}
24+
as::ArraySize,
25+
) where {T, B} = CUDA.CuArray{T, ndims(as), B}
2026

2127
# Ensure that both parent array types have the same memory buffer type.
2228
promote_parent_array_type(

src/DataLayouts/DataLayouts.jl

+52-35
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ abstract type AbstractData{S} end
5151
@inline Base.size(data::AbstractData, i::Integer) = size(data)[i]
5252
@inline Base.size(data::AbstractData) = universal_size(data)
5353

54+
struct ArraySize{FD, Nf, S} end
55+
@inline ArraySize(data::AbstractData, i::Integer) = ArraySize(data)[i]
56+
@inline ArraySize(data::AbstractData) = ArraySize{field_dim(data), ncomponents(data), farray_size(data)}()
57+
@inline Base.ndims(::ArraySize{FD, Nf, S}) where {FD, Nf, S} = length(S)
58+
@inline Base.ndims(::Type{ArraySize{FD, Nf, S}}) where {FD, Nf, S} = length(S)
59+
5460
include("field_array.jl")
5561
include("struct.jl")
5662

@@ -296,39 +302,6 @@ end
296302
))
297303
end
298304

299-
for DL in (
300-
:IJKFVH,
301-
:IJFH,
302-
:IFH,
303-
# :DataF, # we want to MethodError for DataF
304-
:IJF,
305-
:IF,
306-
:VF,
307-
:VIJFH,
308-
:VIFH,
309-
:IH1JH2,
310-
:IV1JH2,
311-
)
312-
@eval @propagate_inbounds function Base.getindex(
313-
data::$DL{S},
314-
I::Integer,
315-
) where {S}
316-
@inbounds get_struct(parent(data), S, Val(field_dim(data)), I)
317-
end
318-
@eval @propagate_inbounds function Base.setindex!(
319-
data::$DL{S},
320-
val,
321-
I::Integer,
322-
) where {S}
323-
@inbounds set_struct!(
324-
parent(data),
325-
convert(S, val),
326-
Val(field_dim(data)),
327-
I,
328-
)
329-
end
330-
end
331-
332305
function replace_basetype(data::AbstractData{S}, ::Type{T}) where {S, T}
333306
array = parent(data)
334307
S′ = replace_basetype(eltype(array), T, S)
@@ -540,6 +513,7 @@ Base.@propagate_inbounds slab(data::IFH, v::Integer, h::Integer) = slab(data, h)
540513
@inline function column(data::IFH{S, Ni}, i, h) where {S, Ni}
541514
@boundscheck (1 <= h <= get_Nh(data) && 1 <= i <= Ni) ||
542515
throw(BoundsError(data, (i, h)))
516+
fa = parent(data)
543517
dataview = @inbounds FieldArray{field_dim(DataF)}(
544518
ntuple(jf -> view(parent(fa.arrays[jf]), i, h), ncomponents(fa)),
545519
)
@@ -731,6 +705,7 @@ end
731705
@inline function column(data::IJF{S, Nij}, i, j) where {S, Nij}
732706
@boundscheck (1 <= j <= Nij && 1 <= i <= Nij) ||
733707
throw(BoundsError(data, (i, j)))
708+
fa = parent(data)
734709
dataview = @inbounds FieldArray{field_dim(DataF)}(
735710
ntuple(jf -> view(parent(fa.arrays[jf]), i, j), ncomponents(fa)),
736711
)
@@ -905,8 +880,13 @@ end
905880

906881
@inline function level(data::VF{S}, v) where {S}
907882
@boundscheck (1 <= v <= nlevels(data)) || throw(BoundsError(data, (v)))
908-
array = parent(data)
909-
dataview = @inbounds view(array, v, :)
883+
fa = parent(data)
884+
dataview = @inbounds FieldArray{field_dim(DataF)}(
885+
ntuple(ncomponents(fa)) do jf
886+
view(parent(fa.arrays[jf]), v, :)
887+
end,
888+
)
889+
910890
DataF{S}(dataview)
911891
end
912892

@@ -1247,6 +1227,9 @@ Adapt.adapt_structure(to, data::AbstractData{S}) where {S} =
12471227
rebuild(data::AbstractData, array::AbstractArray) =
12481228
union_all(data){type_params(data)...}(array)
12491229

1230+
rebuild(data::AbstractData, fa::FieldArray) =
1231+
union_all(data){type_params(data)...}(fa)
1232+
12501233
empty_kernel_stats(::ClimaComms.AbstractDevice) = nothing
12511234
empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
12521235

@@ -1459,6 +1442,40 @@ type parameters.
14591442

14601443
#! format: on
14611444

1445+
# Skip DataF here, since we want that to MethodError.
1446+
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH)
1447+
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
1448+
linear_getindex(data, I)
1449+
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =
1450+
linear_setindex!(data, val, I)
1451+
end
1452+
1453+
# Datalayouts
1454+
@propagate_inbounds function linear_getindex(
1455+
data::AbstractData{S},
1456+
I::Integer,
1457+
) where {S}
1458+
s_array = farray_size(data)
1459+
ss = StaticSize(s_array, field_dim(data))
1460+
@inbounds get_struct_linear(parent(data), S, Val(field_dim(data)), ss, I)
1461+
end
1462+
@propagate_inbounds function linear_setindex!(
1463+
data::AbstractData{S},
1464+
val,
1465+
I::Integer,
1466+
) where {S}
1467+
s_array = farray_size(data)
1468+
ss = StaticSize(s_array, field_dim(data))
1469+
@inbounds set_struct_linear!(
1470+
parent(data),
1471+
convert(S, val),
1472+
Val(field_dim(data)),
1473+
ss,
1474+
I,
1475+
)
1476+
end
1477+
1478+
14621479
Base.ndims(data::AbstractData) = Base.ndims(typeof(data))
14631480
Base.ndims(::Type{T}) where {T <: AbstractData} =
14641481
Base.ndims(field_array_type(T))

src/DataLayouts/broadcast.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ Data0DStyle(::Type{DataFStyle{A}}) where {A} = DataFStyle{A}
1919
abstract type DataColumnStyle <: DataStyle end
2020
struct VFStyle{Nv, A} <: DataColumnStyle end
2121
DataStyle(::Type{VF{S, Nv, A}}) where {S, Nv, A} =
22-
VFStyle{Nv, to_parent_array_type(A)}()
22+
VFStyle{Nv, parent_array_type(A)}()
2323
DataColumnStyle(::Type{VFStyle{Nv, A}}) where {Nv, A} = VFStyle{Nv, A}
2424

2525
abstract type Data1DStyle{Ni} <: DataStyle end
@@ -42,7 +42,7 @@ DataStyle(::Type{IJF{S, Nij, A}}) where {S, Nij, A} =
4242
abstract type Data2DStyle{Nij} <: DataStyle end
4343
struct IJFHStyle{Nij, Nh, A} <: Data2DStyle{Nij} end
4444
DataStyle(::Type{IJFH{S, Nij, Nh, A}}) where {S, Nij, Nh, A} =
45-
IJFHStyle{Nij, Nh, to_parent_array_type(A)}()
45+
IJFHStyle{Nij, Nh, parent_array_type(A)}()
4646
DataSlab2DStyle(::Type{IJFHStyle{Nij, Nh, A}}) where {Nij, Nh, A} =
4747
IJFStyle{Nij, A}
4848

@@ -378,8 +378,14 @@ function Base.similar(
378378
) where {Nv, Nij, Nh, A, Eltype, newNv}
379379
T = eltype(A)
380380
Nf = typesize(eltype(A), Eltype)
381-
fat = field_array_type(A, Val(field_dim(VIJFH)), Val(Nf), Val(4))
382-
array = similar(fat, Base.Dims((newNv, Nij, Nij, Nh)))
381+
# fat = rebuild_type(A, Val(field_dim(VIJFH)), Val(Nf), Val(4))
382+
_size = (newNv, Nij, Nij, Nh)
383+
as = ArraySize{field_dim(VIJFH), Nf, _size}()
384+
# fat = if A isa AbstractArray
385+
# field_array_type(A, as)
386+
# else
387+
# end
388+
array = similar(rebuild_field_array_type(A, as), _size)
383389
vd = VIJFH{Eltype, newNv, Nij, Nh}(array)
384390
return vd
385391
end

src/DataLayouts/field_array.jl

+47-19
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,34 @@ float_type(::Type{FieldArray{FD, N, T}}) where {FD, N, T} = eltype(T)
4646
parent_array_type(::Type{FieldArray{FD, N, T}}) where {FD, N, T} =
4747
parent_array_type(T)
4848
# field_array_type(::Type{FieldArray{N,T}}, ::Val{Nf}) where {N,T, Nf} = FieldArray{Nf, parent_array_type(T, Val(ndims(T)))}
49-
field_array_type(
49+
# field_array_type(
50+
# ::Type{FieldArray{FD, N, T}},
51+
# ::Val{Nf},
52+
# ::Val{ND},
53+
# ) where {FD, N, T, Nf, ND} = FieldArray{FD, Nf, parent_array_type(T, Val(ND))}
54+
55+
rebuild_type(
56+
::Type{FieldArray{FD, N, T}},
57+
as::ArraySize{FD,Nf},
58+
) where {FD, N, T, Nf} = FieldArray{FD, Nf, parent_array_type(T, Val(ndims(as)))}
59+
60+
rebuild_field_array_type(
5061
::Type{FieldArray{FD, N, T}},
51-
::Val{Nf},
52-
::Val{ND},
53-
) where {FD, N, T, Nf, ND} = FieldArray{FD, Nf, parent_array_type(T, Val(ND))}
54-
# field_array_type(::Type{T}, ::Val{Nf}) where {T <: AbstractArray, Nf} = FieldArray{Nf, parent_array_type(T, Val(ndims(T)))}
55-
field_array_type(
62+
as::ArraySize{FD,Nf},
63+
) where {FD, N, T, Nf} = FieldArray{FD, Nf, parent_array_type(T, Val(ndims(as)))}
64+
65+
rebuild_field_array_type(
5666
::Type{T},
57-
::Val{FD},
58-
::Val{Nf},
59-
::Val{ND},
60-
) where {T <: AbstractArray, FD, Nf, ND} =
61-
FieldArray{FD, Nf, parent_array_type(T, Val(ND))}
67+
as::ArraySize{FD,Nf,S},
68+
) where {FD, T<:AbstractArray, Nf, S} = FieldArray{FD, Nf, parent_array_type(T, as)}
69+
70+
# field_array_type(
71+
# ::Type{T},
72+
# ::Val{FD},
73+
# ::Val{Nf},
74+
# ::Val{ND},
75+
# ) where {T <: AbstractArray, FD, Nf, ND} =
76+
# FieldArray{FD, Nf, parent_array_type(T, Val(ND))}
6277
Base.ndims(::Type{FieldArray{FD, N, T}}) where {FD, N, T} = Base.ndims(T) + 1
6378
Base.eltype(::Type{FieldArray{FD, N, T}}) where {FD, N, T} = eltype(T)
6479
array_type(::Type{FieldArray{FD, N, T}}) where {FD, N, T} = T
@@ -101,8 +116,16 @@ end
101116
end
102117
end
103118

104-
Base.similar(fa::FieldArray{FD, N, T}) where {FD, N, T} =
105-
FieldArray{FD, N, T}(ntuple(i -> similar(T, ndims(T)), N))
119+
Base.similar(fa::FieldArray{FD, N, T}, dims) where {FD, N, T} =
120+
FieldArray{FD, N, T}(ntuple(i -> similar(T, dims), N))
121+
function Base.similar(::Type{FieldArray{FD, N, T}}, dims) where {FD, N, T}
122+
FieldArray{FD, N, T}(ntuple(i -> similar(T, dims), N))
123+
end
124+
125+
function Base.similar(::Type{FieldArray{FD, N, T}}) where {FD, N, T}
126+
isconcretetype(T) || error("Array type $T is not concrete, pass `dims` to similar or use concrete array type.")
127+
FieldArray{FD, N, T}(ntuple(i -> similar(T), N))
128+
end
106129

107130
@inline insertafter(t::Tuple, i::Int, j::Int) =
108131
0 <= i <= length(t) ? _insertafter(t, i, j) : throw(BoundsError(t, i))
@@ -124,13 +147,12 @@ function Base.collect(fa::FieldArray{FD, N, T}) where {FD, N, T}
124147
return a
125148
end
126149

127-
function Base.similar(::Type{<:FieldArray{FD, N, T}}, s) where {FD, N, T}
128-
FieldArray{FD, N}(ntuple(i -> similar(T, s), N))
129-
end
130-
131150
field_array(array::AbstractArray, s::AbstractDataLayoutSingleton) =
132151
field_array(array, field_dim(s))
133152

153+
field_arrays(fa::FieldArray) = getfield(fa, :arrays)
154+
field_arrays(data::AbstractData) = field_arrays(parent(data))
155+
134156
FieldArray{FD, N}(
135157
fa::FA,
136158
) where {FD, N, T <: AbstractArray, FA <: NTuple{N, T}} =
@@ -164,7 +186,13 @@ function field_array(array::AbstractArray, fdim::Integer)
164186
return FieldArray{fdim, Nf, eltype(arrays)}(arrays)
165187
end
166188

167-
function Base.getindex(fa::FieldArray{FD}, I::CartesianIndex) where {FD}
189+
function Base.:(==)(fa::FieldArray, array::AbstractArray)
190+
return collect(fa) == array
191+
end
192+
193+
Base.getindex(fa::FieldArray, I::Integer...) = getindex(fa, CartesianIndex(I))
194+
195+
function Base.getindex(fa::FieldArray{FD}, I::CartesianIndex{5}) where {FD}
168196
FDI = I.I[FD]
169197
ND = length(I.I)
170198
Ipre = ntuple(i -> I.I[i], Val(FD - 1))
@@ -173,7 +201,7 @@ function Base.getindex(fa::FieldArray{FD}, I::CartesianIndex) where {FD}
173201
return fa.arrays[FDI][IA]
174202
end
175203

176-
function Base.setindex!(fa::FieldArray{FD}, val, I::CartesianIndex) where {FD}
204+
function Base.setindex!(fa::FieldArray{FD}, val, I::CartesianIndex{5}) where {FD}
177205
FDI = I.I[FD]
178206
ND = length(I.I)
179207
Ipre = ntuple(i -> I.I[i], Val(FD - 1))

src/DataLayouts/non_extruded_broadcasted.jl

-31
Original file line numberDiff line numberDiff line change
@@ -127,34 +127,3 @@ Base.IndexStyle(::Type{<:NonExtrudedBroadcasted{<:Any, <:Tuple{Any}}}) = IndexLi
127127
# ============================================================
128128

129129
#! format: on
130-
# Datalayouts
131-
@propagate_inbounds function linear_getindex(
132-
data::AbstractData{S},
133-
I::Integer,
134-
) where {S}
135-
s_array = farray_size(data)
136-
ss = StaticSize(s_array, field_dim(data))
137-
@inbounds get_struct_linear(parent(data), S, Val(field_dim(data)), ss, I)
138-
end
139-
@propagate_inbounds function linear_setindex!(
140-
data::AbstractData{S},
141-
val,
142-
I::Integer,
143-
) where {S}
144-
s_array = farray_size(data)
145-
ss = StaticSize(s_array, field_dim(data))
146-
@inbounds set_struct_linear!(
147-
parent(data),
148-
convert(S, val),
149-
Val(field_dim(data)),
150-
ss,
151-
I,
152-
)
153-
end
154-
155-
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH) # Skip DataF, since we want that to MethodError.
156-
@eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) =
157-
linear_getindex(data, I)
158-
@eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) =
159-
linear_setindex!(data, val, I)
160-
end

src/DataLayouts/struct.jl

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Returns the parent array type underlying any wrapper types, with all
8787
dimensionality information removed.
8888
"""
8989
parent_array_type(::Type{<:Array{T}}) where {T} = Array{T}
90+
parent_array_type(::Type{<:Array{T}}, as::ArraySize) where {T} = Array{T, ndims(as)}
9091
parent_array_type(::Type{<:MArray{S, T, N, L}}) where {S, T, N, L} =
9192
MArray{S, T}
9293
parent_array_type(::Type{<:SubArray{T, N, A}}) where {T, N, A} =

test/DataLayouts/data0d.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ end
129129
FT = Float64
130130
array = rand(FT, 2, 1)
131131
data = IF{FT, 2}(array)
132-
@test DataLayouts.data2array(data) == reshape(parent(data), :)
133-
@test parent(DataLayouts.array2data(DataLayouts.data2array(data), data)) ==
134-
parent(data)
132+
@test DataLayouts.data2array(data) == reshape(collect(parent(data)), :)
133+
@test size(collect(parent(DataLayouts.array2data(DataLayouts.data2array(data), data)))) ==
134+
size(collect(parent(data)))
135135
end
136136

137137
@testset "broadcasting DataF + IFH data object => IFH" begin
@@ -148,9 +148,9 @@ end
148148
FT = Float64
149149
array = rand(FT, 2, 1, Nh)
150150
data = IFH{FT, 2, Nh}(array)
151-
@test DataLayouts.data2array(data) == reshape(parent(data), :)
152-
@test parent(DataLayouts.array2data(DataLayouts.data2array(data), data)) ==
153-
parent(data)
151+
@test DataLayouts.data2array(data) == reshape(collect(parent(data)), :)
152+
@test size(collect(parent(DataLayouts.array2data(DataLayouts.data2array(data), data)))) ==
153+
size(collect(parent(data)))
154154
end
155155

156156
@testset "broadcasting DataF + IJF data object => IJF" begin
@@ -165,9 +165,9 @@ end
165165
FT = Float64
166166
array = rand(FT, 2, 2, 1)
167167
data = IJF{FT, 2}(array)
168-
@test DataLayouts.data2array(data) == reshape(parent(data), :)
169-
@test parent(DataLayouts.array2data(DataLayouts.data2array(data), data)) ==
170-
parent(data)
168+
@test DataLayouts.data2array(data) == reshape(collect(parent(data)), :)
169+
@test size(collect(parent(DataLayouts.array2data(DataLayouts.data2array(data), data)))) ==
170+
size(collect(parent(data)))
171171
end
172172

173173
@testset "broadcasting DataF + IJFH data object => IJFH" begin
@@ -184,9 +184,9 @@ end
184184
Nh = 3
185185
array = rand(FT, 2, 2, 1, Nh)
186186
data = IJFH{FT, 2, Nh}(array)
187-
@test DataLayouts.data2array(data) == reshape(parent(data), :)
188-
@test parent(DataLayouts.array2data(DataLayouts.data2array(data), data)) ==
189-
parent(data)
187+
@test DataLayouts.data2array(data) == reshape(collect(parent(data)), :)
188+
@test size(collect(parent(DataLayouts.array2data(DataLayouts.data2array(data), data)))) ==
189+
size(collect(parent(data)))
190190
end
191191

192192
@testset "broadcasting DataF + VIFH data object => VIFH" begin

0 commit comments

Comments
 (0)