Skip to content

Commit 99132e0

Browse files
Merge pull request #2075 from CliMA/ck/auto_kernel_splitting
Automatically split fused kernels by parameter memory limits
2 parents d0a9f9d + 5b200f5 commit 99132e0

File tree

2 files changed

+65
-38
lines changed

2 files changed

+65
-38
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ LinearAlgebra = "1"
6767
LazyBroadcast = "0.1"
6868
Logging = "1"
6969
MPI = "0.20"
70-
MultiBroadcastFusion = "0.3"
70+
MultiBroadcastFusion = "0.3, 0.4"
7171
NVTX = "0.3"
7272
OrderedCollections = "1"
7373
PkgVersion = "0.1, 0.2, 0.3"

ext/cuda/data_layouts_fused_copyto.jl

+64-37
Original file line numberDiff line numberDiff line change
@@ -72,56 +72,83 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us)
7272
end
7373
return nothing
7474
end
75-
75+
import MultiBroadcastFusion
76+
const MBFCUDA =
77+
Base.get_extension(MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt)
7678
# https://github.com/JuliaLang/julia/issues/56295
7779
# Julia 1.11's Base.Broadcast currently requires
7880
# multiple integer indexing, wheras Julia 1.10 did not.
7981
# This means that we cannot reserve linear indexing to
8082
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
8183
# (including the GPU-variant related issue resolution efforts:
8284
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
85+
86+
function fused_multibroadcast_args(fmb::FusedMultiBroadcast)
87+
dest = first(fmb.pairs).first
88+
us = DataLayouts.UniversalSize(dest)
89+
return (fmb, us)
90+
end
91+
92+
import MultiBroadcastFusion
8393
function fused_copyto!(
84-
fmbc::FusedMultiBroadcast,
94+
fmb::FusedMultiBroadcast,
8595
dest1::DataLayouts.AbstractData,
8696
::ToCUDA,
8797
)
8898
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest1)
89-
if Nv > 0 && Nh > 0
90-
bcs = map(p -> p.second, fmbc.pairs)
91-
destinations = map(p -> p.first, fmbc.pairs)
92-
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
93-
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
94-
!(VERSION v"1.11.0-beta")
95-
pairs′ = map(fmbc.pairs) do p
96-
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
97-
Pair(p.first, Base.Broadcast.instantiate(bc′))
98-
end
99-
us = DataLayouts.UniversalSize(dest1)
100-
fmbc′ = FusedMultiBroadcast(pairs′)
101-
args = (fmbc′, us)
102-
threads = threads_via_occupancy(knl_fused_copyto_linear!, args)
103-
n_max_threads = min(threads, get_N(us))
104-
p = linear_partition(prod(size(dest1)), n_max_threads)
105-
auto_launch!(
106-
knl_fused_copyto_linear!,
107-
args;
108-
threads_s = p.threads,
109-
blocks_s = p.blocks,
110-
always_inline = false,
111-
)
112-
else
113-
us = DataLayouts.UniversalSize(dest1)
114-
args = (fmbc, dest1, us)
115-
threads = threads_via_occupancy(knl_fused_copyto!, args)
116-
n_max_threads = min(threads, get_N(us))
117-
p = partition(dest1, n_max_threads)
118-
auto_launch!(
119-
knl_fused_copyto!,
120-
args;
121-
threads_s = p.threads,
122-
blocks_s = p.blocks,
123-
)
99+
(Nv > 0 && Nh > 0) || return nothing # short circuit
100+
101+
if pkgversion(MultiBroadcastFusion) >= v"0.3.3"
102+
# Automatically split kernels by available parameter memory space:
103+
fmbs = MBFCUDA.partition_kernels(
104+
fmb,
105+
FusedMultiBroadcast,
106+
fused_multibroadcast_args,
107+
)
108+
for fmb in fmbs
109+
launch_fused_copyto!(fmb)
110+
end
111+
else
112+
launch_fused_copyto!(fmb)
113+
end
114+
return nothing
115+
end
116+
117+
function launch_fused_copyto!(fmb::FusedMultiBroadcast)
118+
dest1 = first(fmb.pairs).first
119+
us = DataLayouts.UniversalSize(dest1)
120+
destinations = map(p -> p.first, fmb.pairs)
121+
bcs = map(p -> p.second, fmb.pairs)
122+
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
123+
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
124+
!(VERSION v"1.11.0-beta")
125+
pairs′ = map(fmb.pairs) do p
126+
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
127+
Pair(p.first, Base.Broadcast.instantiate(bc′))
124128
end
129+
fmb′ = FusedMultiBroadcast(pairs′)
130+
args = (fmb′, us)
131+
threads = threads_via_occupancy(knl_fused_copyto_linear!, args)
132+
n_max_threads = min(threads, get_N(us))
133+
p = linear_partition(prod(size(dest1)), n_max_threads)
134+
auto_launch!(
135+
knl_fused_copyto_linear!,
136+
args;
137+
threads_s = p.threads,
138+
blocks_s = p.blocks,
139+
always_inline = false,
140+
)
141+
else
142+
args = (fmb, dest1, us)
143+
threads = threads_via_occupancy(knl_fused_copyto!, args)
144+
n_max_threads = min(threads, get_N(us))
145+
p = partition(dest1, n_max_threads)
146+
auto_launch!(
147+
knl_fused_copyto!,
148+
args;
149+
threads_s = p.threads,
150+
blocks_s = p.blocks,
151+
)
125152
end
126153
return nothing
127154
end

0 commit comments

Comments
 (0)