Skip to content

Commit 9d5f538

Browse files
Add linear index support for pointwise kernels
1 parent ee2b83e commit 9d5f538

15 files changed

+750
-151
lines changed

ext/cuda/data_layouts.jl

+13
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,16 @@ function Adapt.adapt_structure(
5353
end,
5454
)
5555
end
56+
57+
import Adapt
58+
import CUDA
59+
function Adapt.adapt_structure(
60+
to::CUDA.KernelAdaptor,
61+
bc::DataLayouts.NonExtrudedBroadcasted{Style},
62+
) where {Style}
63+
DataLayouts.NonExtrudedBroadcasted{Style}(
64+
adapt_f(to, bc.f),
65+
Adapt.adapt(to, bc.args),
66+
Adapt.adapt(to, bc.axes),
67+
)
68+
end

ext/cuda/data_layouts_copyto.jl

+23-90
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,9 @@
1+
import ClimaCore.DataLayouts:
2+
to_non_extruded_broadcasted, has_uniform_datalayouts
13
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
24

3-
function knl_copyto!(dest, src)
4-
5-
i = CUDA.threadIdx().x
6-
j = CUDA.threadIdx().y
7-
8-
h = CUDA.blockIdx().x
9-
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
10-
11-
if v <= size(dest, 4)
12-
I = CartesianIndex((i, j, 1, v, h))
13-
@inbounds dest[I] = src[I]
14-
end
15-
return nothing
16-
end
17-
18-
function Base.copyto!(
19-
dest::IJFH{S, Nij, Nh},
20-
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
21-
::ToCUDA,
22-
) where {S, Nij, Nh}
23-
if Nh > 0
24-
auto_launch!(
25-
knl_copyto!,
26-
(dest, bc),
27-
dest;
28-
threads_s = (Nij, Nij),
29-
blocks_s = (Nh, 1),
30-
)
31-
end
32-
return dest
33-
end
34-
35-
function Base.copyto!(
36-
dest::VIJFH{S, Nv, Nij, Nh},
37-
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
38-
::ToCUDA,
39-
) where {S, Nv, Nij, Nh}
40-
if Nv > 0 && Nh > 0
41-
Nv_per_block = min(Nv, fld(256, Nij * Nij))
42-
Nv_blocks = cld(Nv, Nv_per_block)
43-
auto_launch!(
44-
knl_copyto!,
45-
(dest, bc),
46-
dest;
47-
threads_s = (Nij, Nij, Nv_per_block),
48-
blocks_s = (Nh, Nv_blocks),
49-
)
50-
end
51-
return dest
52-
end
53-
54-
function Base.copyto!(
55-
dest::VF{S, Nv},
56-
bc::DataLayouts.BroadcastedUnionVF{S, Nv},
57-
::ToCUDA,
58-
) where {S, Nv}
59-
if Nv > 0
60-
auto_launch!(
61-
knl_copyto!,
62-
(dest, bc),
63-
dest;
64-
threads_s = (1, 1),
65-
blocks_s = (1, Nv),
66-
)
67-
end
68-
return dest
69-
end
70-
71-
function Base.copyto!(
72-
dest::DataF{S},
73-
bc::DataLayouts.BroadcastedUnionDataF{S},
74-
::ToCUDA,
75-
) where {S}
76-
auto_launch!(
77-
knl_copyto!,
78-
(dest, bc),
79-
dest;
80-
threads_s = (1, 1),
81-
blocks_s = (1, 1),
82-
)
83-
return dest
84-
end
85-
865
import ClimaCore.DataLayouts: isascalar
87-
function knl_copyto_flat!(dest::AbstractData, bc, us)
6+
function knl_copyto_cart!(dest::AbstractData, bc, us)
887
@inbounds begin
898
tidx = thread_index()
909
if tidx ≤ get_N(us)
@@ -96,24 +15,38 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
9615
return nothing
9716
end
9817

18+
function knl_copyto_linear!(dest::AbstractData, bc, us)
19+
@inbounds begin
20+
tidx = thread_index()
21+
if tidx ≤ get_N(us)
22+
dest[tidx] = bc[tidx]
23+
end
24+
end
25+
return nothing
26+
end
27+
9928
function cuda_copyto!(dest::AbstractData, bc)
10029
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
30+
(Nv > 0 && Nh > 0) || return dest
10131
us = DataLayouts.UniversalSize(dest)
102-
if Nv > 0 && Nh > 0
103-
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
32+
if has_uniform_datalayouts(bc)
33+
bc′ = to_non_extruded_broadcasted(bc)
34+
auto_launch!(knl_copyto_linear!, (dest, bc′, us), dest; auto = true)
35+
else
36+
auto_launch!(knl_copyto_cart!, (dest, bc, us), dest; auto = true)
10437
end
10538
return dest
10639
end
10740

10841
# TODO: can we use CUDA's launch configuration for all data layouts?
10942
# Currently, it seems to have a slight performance degradation.
11043
#! format: off
111-
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
44+
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
11245
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
11346
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
11447
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
11548
Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
116-
# Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
117-
# Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
118-
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
49+
Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
50+
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
51+
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
11952
#! format: on

ext/cuda/data_layouts_fill.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ function knl_fill_flat!(dest::AbstractData, val, us)
22
@inbounds begin
33
tidx = thread_index()
44
if tidx ≤ get_N(us)
5-
n = size(dest)
6-
I = kernel_indexes(tidx, n)
7-
@inbounds dest[I] = val
5+
@inbounds dest[tidx] = val
86
end
97
end
108
return nothing

src/DataLayouts/DataLayouts.jl

+33
Original file line numberDiff line numberDiff line change
@@ -1523,6 +1523,37 @@ get_Nij(::IF{S, Nij}) where {S, Nij} = Nij
15231523
@inline field_dim(::VIJFH) = 4
15241524
@inline field_dim(::VIFH) = 3
15251525

1526+
# Returns the size of the backing array.
1527+
@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} =
1528+
(Nij, Nij, Nk, 1, Nv, Nh)
1529+
@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
1530+
@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
1531+
@inline array_size(::DataF{S}) where {S} = (1,)
1532+
@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
1533+
@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
1534+
@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
1535+
@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
1536+
(Nv, Nij, Nij, 1, Nh)
1537+
@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
1538+
(Nv, Ni, 1, Nh)
1539+
1540+
@inline farray_size(
1541+
data::IJKFVH{S, Nij, Nk, Nv, Nh},
1542+
) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
1543+
@inline farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} =
1544+
(Nij, Nij, ncomponents(data), Nh)
1545+
@inline farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh} =
1546+
(Ni, ncomponents(data), Nh)
1547+
@inline farray_size(data::DataF{S}) where {S} = (ncomponents(data),)
1548+
@inline farray_size(data::IJF{S, Nij}) where {S, Nij} =
1549+
(Nij, Nij, ncomponents(data))
1550+
@inline farray_size(data::IF{S, Ni}) where {S, Ni} = (Ni, ncomponents(data))
1551+
@inline farray_size(data::VF{S, Nv}) where {S, Nv} = (Nv, ncomponents(data))
1552+
@inline farray_size(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
1553+
(Nv, Nij, Nij, ncomponents(data), Nh)
1554+
@inline farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
1555+
(Nv, Ni, ncomponents(data), Nh)
1556+
15261557
#! format: off
15271558
@inline to_data_specific(::IJFH, I::CartesianIndex) = CartesianIndex(I.I[1], I.I[2], 1, I.I[5])
15281559
@inline to_data_specific(::IFH, I::CartesianIndex) = CartesianIndex(I.I[1], 1, I.I[5])
@@ -1600,10 +1631,12 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
16001631
_device_dispatch(x::SArray) = ToCPU()
16011632
_device_dispatch(x::MArray) = ToCPU()
16021633

1634+
include("non_extruded_broadcasted.jl")
16031635
include("copyto.jl")
16041636
include("fused_copyto.jl")
16051637
include("fill.jl")
16061638
include("mapreduce.jl")
1639+
include("has_uniform_datalayouts.jl")
16071640

16081641
slab_index(i, j) = CartesianIndex(i, j, 1, 1, 1)
16091642
slab_index(i) = CartesianIndex(i, 1, 1, 1, 1)

src/DataLayouts/broadcast.jl

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
7373
#####
7474

7575
#! format: off
76+
const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
7677
const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
7778
const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
7879
const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}

src/DataLayouts/copyto.jl

+15-3
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,22 @@
22
##### Dispatching and edge cases
33
#####
44

5-
Base.copyto!(
6-
dest::AbstractData,
5+
function Base.copyto!(
6+
dest::AbstractData{S},
77
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
8-
) = Base.copyto!(dest, bc, device_dispatch(dest))
8+
) where {S}
9+
dev = device_dispatch(dest)
10+
if dev isa ToCPU && has_uniform_datalayouts(bc)
11+
# Specialize on linear indexing case:
12+
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
13+
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
14+
dest[I] = convert(S, bc′[I])
15+
end
16+
else
17+
Base.copyto!(dest, bc, device_dispatch(dest))
18+
end
19+
return dest
20+
end
921

1022
# Specialize on non-Broadcasted objects
1123
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}

src/DataLayouts/fill.jl

+7-54
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,13 @@
1-
function Base.fill!(data::IJFH, val, ::ToCPU)
2-
(_, _, _, _, Nh) = size(data)
3-
@inbounds for h in 1:Nh
4-
fill!(slab(data, h), val)
1+
function Base.fill!(dest::AbstractData, val, ::ToCPU)
2+
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
3+
dest[I] = val
54
end
6-
return data
5+
return dest
76
end
87

9-
function Base.fill!(data::IFH, val, ::ToCPU)
10-
(_, _, _, _, Nh) = size(data)
11-
@inbounds for h in 1:Nh
12-
fill!(slab(data, h), val)
13-
end
14-
return data
15-
end
16-
17-
function Base.fill!(data::DataF, val, ::ToCPU)
18-
@inbounds data[] = val
19-
return data
20-
end
21-
22-
function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
23-
@inbounds for j in 1:Nij, i in 1:Nij
24-
data[CartesianIndex(i, j, 1, 1, 1)] = val
25-
end
26-
return data
27-
end
28-
29-
function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
30-
@inbounds for i in 1:Ni
31-
data[CartesianIndex(i, 1, 1, 1, 1)] = val
32-
end
33-
return data
34-
end
35-
36-
function Base.fill!(data::VF, val, ::ToCPU)
37-
Nv = nlevels(data)
38-
@inbounds for v in 1:Nv
39-
data[CartesianIndex(1, 1, 1, v, 1)] = val
40-
end
41-
return data
42-
end
43-
44-
function Base.fill!(data::VIJFH, val, ::ToCPU)
45-
(Ni, Nj, _, Nv, Nh) = size(data)
46-
@inbounds for h in 1:Nh, v in 1:Nv
47-
fill!(slab(data, v, h), val)
48-
end
49-
return data
50-
end
51-
52-
function Base.fill!(data::VIFH, val, ::ToCPU)
53-
(Ni, _, _, Nv, Nh) = size(data)
54-
@inbounds for h in 1:Nh, v in 1:Nv
55-
fill!(slab(data, v, h), val)
56-
end
57-
return data
8+
function Base.fill!(dest::DataF, val, ::ToCPU)
9+
@inbounds dest[] = val
10+
return dest
5811
end
5912

6013
Base.fill!(dest::AbstractData, val) =
src/DataLayouts/has_uniform_datalayouts.jl

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
2+
x1 = first_datalayout_in_bc(args[1], rargs...)
3+
x1 isa AbstractData && return x1
4+
return first_datalayout_in_bc(Base.tail(args), rargs...)
5+
end
6+
7+
@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
8+
first_datalayout_in_bc(args[1], rargs...)
9+
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
10+
@inline first_datalayout_in_bc(x) = nothing
11+
@inline first_datalayout_in_bc(x::AbstractData) = x
12+
13+
@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
14+
first_datalayout_in_bc(bc.args)
15+
16+
@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
17+
truesofar &&
18+
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
19+
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)
20+
21+
@inline _has_uniform_datalayouts_args(
22+
truesofar,
23+
start,
24+
args::Tuple{Any},
25+
rargs...,
26+
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
27+
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
28+
truesofar
29+
30+
@inline function _has_uniform_datalayouts(
31+
truesofar,
32+
start,
33+
bc::Base.Broadcast.Broadcasted,
34+
)
35+
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
36+
end
37+
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
38+
@eval begin
39+
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
40+
end
41+
end
42+
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
43+
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar
44+
45+
"""
46+
has_uniform_datalayouts
47+
Find the first datalayout in the broadcast expression (BCE),
48+
and compare it against every other datalayout in the BCE. Returns
49+
- `true` if the broadcasted object has only a single kind of datalayout (e.g. VF,VF, VIJFH,VIJFH)
50+
- `false` if the broadcasted object has multiple kinds of datalayouts (e.g. VIJFH, VIFH)
51+
Note: a broadcasted object can have different _types_,
52+
e.g., `VIJFH{Float64}` and `VIJFH{Tuple{Float64,Float64}}`
53+
but not different kinds, e.g., `VIJFH{Float64}` and `VF{Float64}`.
54+
"""
55+
function has_uniform_datalayouts end
56+
57+
@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
58+
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)
59+
60+
@inline has_uniform_datalayouts(bc::AbstractData) = true

0 commit comments

Comments
 (0)