Skip to content

Commit 378bbcf

Browse files
Add linear index support for pointwise kernels
1 parent 3bc75d1 commit 378bbcf

14 files changed

+730
-68
lines changed

ext/cuda/data_layouts.jl

+13
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,16 @@ function Adapt.adapt_structure(
5454
end,
5555
)
5656
end
57+
58+
import Adapt
59+
import CUDA
60+
# Adapt a `NonExtrudedBroadcasted` with a `CUDA.KernelAdaptor`: the function,
# the arguments and the axes are each adapted, and the result is rebuilt with
# the same broadcast `Style` parameter.
function Adapt.adapt_structure(
    to::CUDA.KernelAdaptor,
    bc::DataLayouts.NonExtrudedBroadcasted{Style},
) where {Style}
    adapted_f = adapt_f(to, bc.f)
    adapted_args = Adapt.adapt(to, bc.args)
    adapted_axes = Adapt.adapt(to, bc.axes)
    return DataLayouts.NonExtrudedBroadcasted{Style}(
        adapted_f,
        adapted_args,
        adapted_axes,
    )
end

src/DataLayouts/DataLayouts.jl

+23
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,27 @@ empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
965965
@inline get_Nij(::IJF{S, Nij}) where {S, Nij} = Nij
966966
@inline get_Nij(::IF{S, Nij}) where {S, Nij} = Nij
967967

968+
# Size tuple of the backing array for each layout. The slot where
# `farray_size` reports `ncomponents(data)` holds a literal `1` here.
@inline function array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh}
    return (Nij, Nij, Nk, 1, Nv, Nh)
end
@inline function array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh}
    return (Nij, Nij, 1, Nh)
end
@inline function array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh}
    return (Ni, 1, Nh)
end
@inline function array_size(::DataF{S}) where {S}
    return (1,)
end
@inline function array_size(::IJF{S, Nij}) where {S, Nij}
    return (Nij, Nij, 1)
end
@inline function array_size(::IF{S, Ni}) where {S, Ni}
    return (Ni, 1)
end
@inline function array_size(::VF{S, Nv}) where {S, Nv}
    return (Nv, 1)
end
@inline function array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh}
    return (Nv, Nij, Nij, 1, Nh)
end
@inline function array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh}
    return (Nv, Ni, 1, Nh)
end
978+
979+
# Size tuple of the backing array for each layout, with `ncomponents(data)`
# in the slot that `array_size` fills with `1`.
@inline function farray_size(
    data::IJKFVH{S, Nij, Nk, Nv, Nh},
) where {S, Nij, Nk, Nv, Nh}
    return (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
end
@inline function farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh}
    return (Nij, Nij, ncomponents(data), Nh)
end
@inline function farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh}
    return (Ni, ncomponents(data), Nh)
end
@inline function farray_size(data::DataF{S}) where {S}
    return (ncomponents(data),)
end
@inline function farray_size(data::IJF{S, Nij}) where {S, Nij}
    return (Nij, Nij, ncomponents(data))
end
@inline function farray_size(data::IF{S, Ni}) where {S, Ni}
    return (Ni, ncomponents(data))
end
@inline function farray_size(data::VF{S, Nv}) where {S, Nv}
    return (Nv, ncomponents(data))
end
@inline function farray_size(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh}
    return (Nv, Nij, Nij, ncomponents(data), Nh)
end
@inline function farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh}
    return (Nv, Ni, ncomponents(data), Nh)
end
988+
968989
"""
969990
field_dim(data::AbstractData)
970991
field_dim(::Type{<:AbstractData})
@@ -1216,9 +1237,11 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
12161237
_device_dispatch(x::SArray) = ToCPU()
12171238
_device_dispatch(x::MArray) = ToCPU()
12181239

1240+
include("non_extruded_broadcasted.jl")
12191241
include("copyto.jl")
12201242
include("fused_copyto.jl")
12211243
include("fill.jl")
12221244
include("mapreduce.jl")
1245+
include("has_uniform_datalayouts.jl")
12231246

12241247
end # module

src/DataLayouts/broadcast.jl

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
7373
#####
7474

7575
#! format: off
76+
const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
7677
const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
7778
const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
7879
const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}

src/DataLayouts/copyto.jl

+15-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,22 @@
22
##### Dispatching and edge cases
33
#####
44

5-
Base.copyto!(
6-
dest::AbstractData,
5+
# Copy `bc` (a datalayout or a broadcast expression) into `dest` and return
# `dest`.
#
# When the dispatch tag of `dest` is `ToCPU`, every datalayout in `bc` is of
# the same kind (`has_uniform_datalayouts`), and `dest` is not a `DataF`, the
# expression is converted to a `NonExtrudedBroadcasted` and evaluated with a
# single loop over linear indices. Every other case is forwarded to the
# three-argument `copyto!` method selected by the dispatch tag.
function Base.copyto!(
    dest::AbstractData{S},
    bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) where {S}
    dev = device_dispatch(dest)
    if dev isa ToCPU && has_uniform_datalayouts(bc) && !(dest isa DataF)
        # Specialize on linear indexing case:
        bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
        @inbounds @simd for I in 1:get_N(UniversalSize(dest))
            dest[I] = convert(S, bc′[I])
        end
    else
        # Reuse the tag computed above instead of calling `device_dispatch`
        # a second time.
        Base.copyto!(dest, bc, dev)
    end
    return dest
end
921

1022
# Specialize on non-Broadcasted objects
1123
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}

src/DataLayouts/fill.jl

+7-54
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,13 @@
1-
function Base.fill!(data::IJFH, val, ::ToCPU)
2-
(_, _, _, _, Nh) = size(data)
3-
@inbounds for h in 1:Nh
4-
fill!(slab(data, h), val)
1+
# `ToCPU` fill: assign `val` at every linear index `1:get_N(UniversalSize(dest))`
# of `dest`, then return `dest`.
function Base.fill!(dest::AbstractData, val, ::ToCPU)
    npoints = get_N(UniversalSize(dest))
    @inbounds @simd for idx in 1:npoints
        dest[idx] = val
    end
    return dest
end
87

9-
function Base.fill!(data::IFH, val, ::ToCPU)
10-
(_, _, _, _, Nh) = size(data)
11-
@inbounds for h in 1:Nh
12-
fill!(slab(data, h), val)
13-
end
14-
return data
15-
end
16-
17-
function Base.fill!(data::DataF, val, ::ToCPU)
18-
@inbounds data[] = val
19-
return data
20-
end
21-
22-
function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
23-
@inbounds for j in 1:Nij, i in 1:Nij
24-
data[CartesianIndex(i, j, 1, 1, 1)] = val
25-
end
26-
return data
27-
end
28-
29-
function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
30-
@inbounds for i in 1:Ni
31-
data[CartesianIndex(i, 1, 1, 1, 1)] = val
32-
end
33-
return data
34-
end
35-
36-
function Base.fill!(data::VF, val, ::ToCPU)
37-
Nv = nlevels(data)
38-
@inbounds for v in 1:Nv
39-
data[CartesianIndex(1, 1, 1, v, 1)] = val
40-
end
41-
return data
42-
end
43-
44-
function Base.fill!(data::VIJFH, val, ::ToCPU)
45-
(Ni, Nj, _, Nv, Nh) = size(data)
46-
@inbounds for h in 1:Nh, v in 1:Nv
47-
fill!(slab(data, v, h), val)
48-
end
49-
return data
50-
end
51-
52-
function Base.fill!(data::VIFH, val, ::ToCPU)
53-
(Ni, _, _, Nv, Nh) = size(data)
54-
@inbounds for h in 1:Nh, v in 1:Nv
55-
fill!(slab(data, v, h), val)
56-
end
57-
return data
8+
# `ToCPU` fill for `DataF`: store `val` through zero-argument indexing
# and return `dest`.
function Base.fill!(dest::DataF, val, ::ToCPU)
    @inbounds begin
        dest[] = val
    end
    return dest
end
5912

6013
Base.fill!(dest::AbstractData, val) =
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Walk a broadcast expression depth-first, left to right, and return the first
# `AbstractData` encountered, or `nothing` when there is none.

# Tuple of arguments: inspect the head, then recurse on the tail.
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
    candidate = first_datalayout_in_bc(args[1], rargs...)
    if candidate isa AbstractData
        return candidate
    end
    return first_datalayout_in_bc(Base.tail(args), rargs...)
end

# Single-element tuple: the answer is whatever the only element yields.
@inline function first_datalayout_in_bc(args::Tuple{Any}, rargs...)
    return first_datalayout_in_bc(args[1], rargs...)
end

# Empty tuple and non-datalayout leaves contain no datalayout.
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_datalayout_in_bc(x) = nothing

# A datalayout leaf is its own answer.
@inline first_datalayout_in_bc(x::AbstractData) = x

# A broadcast node: search its arguments.
@inline function first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted)
    return first_datalayout_in_bc(bc.args)
end
15+
16+
# Fold `_has_uniform_datalayouts` over a tuple of broadcast arguments,
# stopping at the first `false`. `start` is the reference datalayout that
# every other datalayout is compared against.
@inline function _has_uniform_datalayouts_args(
    truesofar,
    start,
    args::Tuple,
    rargs...,
)
    truesofar || return truesofar
    head_ok = _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
    head_ok || return head_ok
    return _has_uniform_datalayouts_args(
        truesofar,
        start,
        Base.tail(args),
        rargs...,
    )
end

# Single-element tuple: only the one argument needs checking.
@inline function _has_uniform_datalayouts_args(
    truesofar,
    start,
    args::Tuple{Any},
    rargs...,
)
    truesofar || return truesofar
    return _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
end

# Empty tuple: nothing left to check, so the running result stands.
@inline function _has_uniform_datalayouts_args(
    truesofar,
    _,
    args::Tuple{},
    rargs...,
)
    return truesofar
end
29+
30+
# Broadcast node: uniform when all of its arguments are.
@inline _has_uniform_datalayouts(
    truesofar,
    start,
    bc::Base.Broadcast.Broadcasted,
) = truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)

# Two datalayouts of the same kind (type parameters may differ) match.
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
    @eval @inline _has_uniform_datalayouts(truesofar, ::$DL, ::$DL) = true
end

# Any other datalayout does not match the reference.
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false

# Non-datalayout leaves do not affect the running result.
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar
44+
45+
"""
    has_uniform_datalayouts(bc)

Locate the first datalayout in the broadcast expression `bc` and compare it
against every other datalayout in the expression. Return

 - `true` if `bc` contains only a single kind of datalayout
   (e.g. `VF, VF` or `VIJFH, VIJFH`)
 - `false` if `bc` contains more than one kind of datalayout
   (e.g. `VIJFH, VIFH`)

Datalayouts of the same kind may still differ in their type parameters,
e.g. `VIJFH{Float64}` and `VIJFH{Tuple{Float64, Float64}}` count as uniform,
whereas `VIJFH{Float64}` and `VF{Float64}` do not.

A bare datalayout (not wrapped in a broadcast) is always uniform.
"""
function has_uniform_datalayouts end

@inline function has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted)
    reference = first_datalayout_in_bc(bc)
    return _has_uniform_datalayouts_args(true, reference, bc.args)
end

@inline function has_uniform_datalayouts(bc::AbstractData)
    return true
end
+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#! format: off
2+
# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
3+
import Base.Broadcast: BroadcastStyle
4+
"""
    NonExtrudedBroadcasted(style, f, args, axes = nothing)
    NonExtrudedBroadcasted(f, args, axes = nothing)
    NonExtrudedBroadcasted{Style}(f, args, axes = nothing)

Lazy broadcast container with the same fields as `Base.Broadcast.Broadcasted`
(`style`, `f`, `args`, `axes`), adapted from Base.Broadcast (Julia 1.10.4).

When `style` is not given, it is computed from `args` with
`Base.Broadcast.combine_styles`. In the `{Style}` forms the style instance is
created as `Style()`.
"""
struct NonExtrudedBroadcasted{
    Style <: Union{Nothing, BroadcastStyle},
    Axes,
    F,
    Args <: Tuple,
} <: Base.AbstractBroadcasted
    style::Style
    f::F
    args::Args
    axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `NonExtrudedBroadcasted`)

    NonExtrudedBroadcasted(style::Union{Nothing, BroadcastStyle}, f::Tuple, args::Tuple) =
        error() # disambiguation: tuple is not callable
    function NonExtrudedBroadcasted(
        style::Union{Nothing, BroadcastStyle},
        f::F,
        args::Tuple,
        axes = nothing,
    ) where {F}
        # using Core.Typeof rather than F preserves inferrability when f is a type
        return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(
            style,
            f,
            args,
            axes,
        )
    end
    function NonExtrudedBroadcasted(f::F, args::Tuple, axes = nothing) where {F}
        # `combine_styles` is not exported from Base.Broadcast, so it must be
        # qualified here; only `BroadcastStyle` is imported by this file.
        style = Base.Broadcast.combine_styles(args...)::BroadcastStyle
        return NonExtrudedBroadcasted(style, f, args, axes)
    end
    function NonExtrudedBroadcasted{Style}(f::F, args, axes = nothing) where {Style, F}
        return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(
            Style()::Style,
            f,
            args,
            axes,
        )
    end
    function NonExtrudedBroadcasted{Style, Axes, F, Args}(
        f,
        args,
        axes,
    ) where {Style, Axes, F, Args}
        return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
    end
end
50+
51+
# Convert a `Base.Broadcast.Broadcasted` tree into the equivalent
# `NonExtrudedBroadcasted` tree, keeping style, function and axes and
# converting the arguments recursively.
@inline function to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted)
    converted_args = to_non_extruded_broadcasted(bc.args)
    return NonExtrudedBroadcasted(bc.style, bc.f, converted_args, bc.axes)
end

# Anything that is not a `Broadcasted` or a tuple is passed through unchanged.
@inline to_non_extruded_broadcasted(x) = x

# Constructor form of the same conversion.
NonExtrudedBroadcasted(bc::Base.Broadcast.Broadcasted) = to_non_extruded_broadcasted(bc)

# Tuples are converted element by element (head first, then the tail).
@inline function to_non_extruded_broadcasted(args::Tuple)
    head = to_non_extruded_broadcasted(args[1])
    rest = to_non_extruded_broadcasted(Base.tail(args))
    return (head, rest...)
end
@inline function to_non_extruded_broadcasted(args::Tuple{Any})
    return (to_non_extruded_broadcasted(args[1]),)
end
@inline to_non_extruded_broadcasted(args::Tuple{}) = ()
63+
64+
# Bounds-check helper: only performs a check when the axes are a tuple.
@inline _checkbounds(bc, _, I) = nothing # TODO: fix this case
@inline function _checkbounds(bc, ::Tuple, I)
    return Base.checkbounds(bc, I)
end

# Index into the lazy broadcast: optionally bounds-check, then evaluate the
# expression at `I` with bounds checks disabled.
@inline function Base.getindex(
    bc::NonExtrudedBroadcasted,
    I::Union{Integer, CartesianIndex},
)
    @boundscheck _checkbounds(bc, axes(bc), I) # is this really the only issue?
    return @inbounds _broadcast_getindex(bc, I)
end
73+
74+
# --- custom bounds check for linear indices
# Base checks `I` against `axes(bc)`; here `I` is checked against the single
# range `1:n_dofs(bc)` instead. Throws a `BoundsError` when out of range.
@inline function Base.checkbounds(bc::NonExtrudedBroadcasted, I::Integer)
    linear_axes = (Base.OneTo(n_dofs(bc)),)
    is_valid = Base.checkbounds_indices(Bool, linear_axes, (I,))
    return is_valid || Base.throw_boundserror(bc, (I,))
end
79+
80+
import StaticArrays
81+
# Convert a tuple of axes into a tuple of lengths.
#  - Tuples of `Base.OneTo` / `StaticArrays.SOneTo` map to their last index.
#  - Any other tuple is returned unchanged.
to_tuple(t::Tuple) = t
# The empty tuple matches both `NTuple{N, <:…}` methods below (with `N == 0`);
# this method resolves that ambiguity.
to_tuple(t::Tuple{}) = t
to_tuple(t::NTuple{N, <:Base.OneTo}) where {N} = map(last, t)
to_tuple(t::NTuple{N, <:StaticArrays.SOneTo}) where {N} = map(last, t)

# Product of the values that `to_tuple` returns for `axes(bc)`.
n_dofs(bc::NonExtrudedBroadcasted) = prod(to_tuple(axes(bc)))
85+
# ---
86+
87+
# Scalar-like arguments ignore the index entirely.
Base.@propagate_inbounds function _broadcast_getindex(
    A::Union{Ref, AbstractArray{<:Any, 0}, Number},
    I::Integer,
)
    return A[]
end

# A `Ref` wrapping a type yields the type itself.
Base.@propagate_inbounds function _broadcast_getindex(
    ::Ref{Type{T}},
    I::Integer,
) where {T}
    return T
end

# Tuples are statically known to be singleton or vector-like.
Base.@propagate_inbounds function _broadcast_getindex(A::Tuple{Any}, I::Integer)
    return A[1]
end
Base.@propagate_inbounds function _broadcast_getindex(A::Tuple, I::Integer)
    return A[I[1]]
end

# Everything else is indexed directly with the linear index. (Base instead
# drops broadcasted indices via `A[newindex(A, I)]`.)
Base.@propagate_inbounds function _broadcast_getindex(A, I::Integer)
    return A[I]
end
101+
# Evaluate a lazy broadcast node at linear index `I`: fetch every argument at
# `I`, then apply the node's function to the fetched values.
Base.@propagate_inbounds function _broadcast_getindex(
    bc::NonExtrudedBroadcasted{<:Any, <:Any, <:Any, <:Any},
    I::Integer,
)
    evaluated = _getindex(bc.args, I)
    return _broadcast_getindex_evalf(bc.f, evaluated...)
end
108+
# Apply `f` to the already-fetched argument values.
@inline function _broadcast_getindex_evalf(
    f::Tf,
    args::Vararg{Any, N},
) where {Tf, N}
    return f(args...) # not propagate_inbounds
end

# Fetch every element of `args` at index `I`, returning a tuple of values.
Base.@propagate_inbounds function _getindex(args::Tuple, I)
    head = _broadcast_getindex(args[1], I)
    return (head, _getindex(Base.tail(args), I)...)
end
Base.@propagate_inbounds function _getindex(args::Tuple{Any}, I)
    return (_broadcast_getindex(args[1], I),)
end
Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
115+
116+
# Axes of the lazy broadcast: the stored axes when present, otherwise the
# axes combined from the arguments.
@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes)
_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes
@inline _axes(bc::NonExtrudedBroadcasted, ::Nothing) = Base.Broadcast.combine_axes(bc.args...)
_axes(bc::NonExtrudedBroadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}, ::Nothing) = ()

# Axis along dimension `d`; dimensions past the stored ones have length 1.
# `OneTo` is not exported from Base, so it is qualified.
@inline Base.axes(bc::NonExtrudedBroadcasted{<:Any, <:NTuple{N}}, d::Integer) where {N} =
    d <= N ? axes(bc)[d] : Base.OneTo(1)
Base.IndexStyle(::Type{<:NonExtrudedBroadcasted{<:Any, <:Tuple{Any}}}) = IndexLinear()

# Fallback: any other stored axes object is returned as-is.
@inline _axes(::NonExtrudedBroadcasted, axes) = axes

# Element type produced by evaluating the broadcast: computed from the function
# and the element types of the arguments (previously this returned the axes).
@inline Base.eltype(bc::NonExtrudedBroadcasted) =
    Base.Broadcast.combine_eltypes(bc.f, bc.args)
125+
126+
127+
# ============================================================
128+
129+
#! format: on
130+
# Datalayouts
131+
# Read the struct of type `S` stored at linear index `I` of `data`.
# The parent array is read through `get_struct_linear`, using a `StaticSize`
# built from `farray_size(data)` and `field_dim(data)`.
@propagate_inbounds function linear_getindex(
    data::AbstractData{S},
    I::Integer,
) where {S}
    static_size = StaticSize(farray_size(data), field_dim(data))
    return @inbounds get_struct_linear(
        parent(data),
        S,
        Val(field_dim(data)),
        I,
        static_size,
    )
end
139+
# Store `val` (converted to `S`) at linear index `I` of `data`.
# The parent array is written through `set_struct_linear!`, using a
# `StaticSize` built from `farray_size(data)` and `field_dim(data)`.
@propagate_inbounds function linear_setindex!(
    data::AbstractData{S},
    val,
    I::Integer,
) where {S}
    static_size = StaticSize(farray_size(data), field_dim(data))
    return @inbounds set_struct_linear!(
        parent(data),
        convert(S, val),
        Val(field_dim(data)),
        I,
        static_size,
    )
end
154+
155+
# Define integer (linear) `getindex` / `setindex!` for every layout listed.
# Skip DataF, since we want that to MethodError.
for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH)
    @eval begin
        @propagate_inbounds function Base.getindex(data::$DL, I::Integer)
            return linear_getindex(data, I)
        end
        @propagate_inbounds function Base.setindex!(data::$DL, val, I::Integer)
            return linear_setindex!(data, val, I)
        end
    end
end

0 commit comments

Comments
 (0)