@@ -72,56 +72,83 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us)
72
72
end
73
73
return nothing
74
74
end
75
-
75
+ import MultiBroadcastFusion
76
+ const MBFCUDA =
77
+ Base. get_extension (MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt )
76
78
# https://github.com/JuliaLang/julia/issues/56295
77
79
# Julia 1.11's Base.Broadcast currently requires
78
80
# multiple integer indexing, whereas Julia 1.10 did not.
79
81
# This means that we cannot reserve linear indexing to
80
82
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
81
83
# (including the GPU-variant related issue resolution efforts:
82
84
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).
85
+
86
+ function fused_multibroadcast_args (fmb:: FusedMultiBroadcast )
87
+ dest = first (fmb. pairs). first
88
+ us = DataLayouts. UniversalSize (dest)
89
+ return (fmb, us)
90
+ end
91
+
92
+ import MultiBroadcastFusion
83
93
function fused_copyto! (
84
- fmbc :: FusedMultiBroadcast ,
94
+ fmb :: FusedMultiBroadcast ,
85
95
dest1:: DataLayouts.AbstractData ,
86
96
:: ToCUDA ,
87
97
)
88
98
(_, _, Nv, _, Nh) = DataLayouts. universal_size (dest1)
89
- if Nv > 0 && Nh > 0
90
- bcs = map (p -> p. second, fmbc. pairs)
91
- destinations = map (p -> p. first, fmbc. pairs)
92
- if all (bc -> DataLayouts. has_uniform_datalayouts (bc), bcs) &&
93
- all (d -> d isa DataLayouts. EndsWithField, destinations) &&
94
- ! (VERSION ≥ v " 1.11.0-beta" )
95
- pairs′ = map (fmbc. pairs) do p
96
- bc′ = DataLayouts. to_non_extruded_broadcasted (p. second)
97
- Pair (p. first, Base. Broadcast. instantiate (bc′))
98
- end
99
- us = DataLayouts. UniversalSize (dest1)
100
- fmbc′ = FusedMultiBroadcast (pairs′)
101
- args = (fmbc′, us)
102
- threads = threads_via_occupancy (knl_fused_copyto_linear!, args)
103
- n_max_threads = min (threads, get_N (us))
104
- p = linear_partition (prod (size (dest1)), n_max_threads)
105
- auto_launch! (
106
- knl_fused_copyto_linear!,
107
- args;
108
- threads_s = p. threads,
109
- blocks_s = p. blocks,
110
- always_inline = false ,
111
- )
112
- else
113
- us = DataLayouts. UniversalSize (dest1)
114
- args = (fmbc, dest1, us)
115
- threads = threads_via_occupancy (knl_fused_copyto!, args)
116
- n_max_threads = min (threads, get_N (us))
117
- p = partition (dest1, n_max_threads)
118
- auto_launch! (
119
- knl_fused_copyto!,
120
- args;
121
- threads_s = p. threads,
122
- blocks_s = p. blocks,
123
- )
99
+ (Nv > 0 && Nh > 0 ) || return nothing # short circuit
100
+
101
+ if pkgversion (MultiBroadcastFusion) >= v " 0.3.3"
102
+ # Automatically split kernels by available parameter memory space:
103
+ fmbs = MBFCUDA. partition_kernels (
104
+ fmb,
105
+ FusedMultiBroadcast,
106
+ fused_multibroadcast_args,
107
+ )
108
+ for fmb in fmbs
109
+ launch_fused_copyto! (fmb)
110
+ end
111
+ else
112
+ launch_fused_copyto! (fmb)
113
+ end
114
+ return nothing
115
+ end
116
+
117
+ function launch_fused_copyto! (fmb:: FusedMultiBroadcast )
118
+ dest1 = first (fmb. pairs). first
119
+ us = DataLayouts. UniversalSize (dest1)
120
+ destinations = map (p -> p. first, fmb. pairs)
121
+ bcs = map (p -> p. second, fmb. pairs)
122
+ if all (bc -> DataLayouts. has_uniform_datalayouts (bc), bcs) &&
123
+ all (d -> d isa DataLayouts. EndsWithField, destinations) &&
124
+ ! (VERSION ≥ v " 1.11.0-beta" )
125
+ pairs′ = map (fmb. pairs) do p
126
+ bc′ = DataLayouts. to_non_extruded_broadcasted (p. second)
127
+ Pair (p. first, Base. Broadcast. instantiate (bc′))
124
128
end
129
+ fmb′ = FusedMultiBroadcast (pairs′)
130
+ args = (fmb′, us)
131
+ threads = threads_via_occupancy (knl_fused_copyto_linear!, args)
132
+ n_max_threads = min (threads, get_N (us))
133
+ p = linear_partition (prod (size (dest1)), n_max_threads)
134
+ auto_launch! (
135
+ knl_fused_copyto_linear!,
136
+ args;
137
+ threads_s = p. threads,
138
+ blocks_s = p. blocks,
139
+ always_inline = false ,
140
+ )
141
+ else
142
+ args = (fmb, dest1, us)
143
+ threads = threads_via_occupancy (knl_fused_copyto!, args)
144
+ n_max_threads = min (threads, get_N (us))
145
+ p = partition (dest1, n_max_threads)
146
+ auto_launch! (
147
+ knl_fused_copyto!,
148
+ args;
149
+ threads_s = p. threads,
150
+ blocks_s = p. blocks,
151
+ )
125
152
end
126
153
return nothing
127
154
end
0 commit comments