Skip to content

Commit 3438e2c

Browse files
Reduce number of specialized methods
1 parent 193ecfa commit 3438e2c

File tree

6 files changed

+44
-46
lines changed

6 files changed

+44
-46
lines changed

ext/cuda/data_layouts_copyto.jl

+20-12
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
1+
DataLayouts.device_dispatch(x::CUDA.CuArray) = ToCUDA()
22

33
function knl_copyto!(dest, src, us)
44
I = universal_index(dest)
@@ -8,7 +8,7 @@ function knl_copyto!(dest, src, us)
88
return nothing
99
end
1010

11-
function cuda_copyto!(dest::AbstractData, bc)
11+
function Base.copyto!(dest::AbstractData, bc, ::ToCUDA)
1212
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
1313
us = DataLayouts.UniversalSize(dest)
1414
if Nv > 0 && Nh > 0
@@ -26,13 +26,21 @@ function cuda_copyto!(dest::AbstractData, bc)
2626
return dest
2727
end
2828

29-
#! format: off
30-
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
31-
Base.copyto!(dest::IFH{S, Ni}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
32-
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
33-
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
34-
Base.copyto!(dest::VIFH{S, Nv, Ni}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni}, ::ToCUDA) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
35-
Base.copyto!(dest::VIJFH{S, Nv, Nij}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, ::ToCUDA) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
36-
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
37-
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
38-
#! format: on
29+
# broadcasting scalar assignment
30+
# Performance optimization for the common identity scalar case: dest .= val
31+
# And this is valid for the CPU or GPU, since the broadcasted object
32+
# is a scalar type.
33+
function Base.copyto!(
34+
dest::AbstractData,
35+
bc::Base.Broadcast.Broadcasted{Style},
36+
::ToCUDA,
37+
) where {
38+
Style <:
39+
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
40+
}
41+
bc = Base.Broadcast.instantiate(
42+
Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
43+
)
44+
@inbounds bc0 = bc[]
45+
fill!(dest, bc0)
46+
end

ext/cuda/data_layouts_fill.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ function knl_fill!(dest, val, us)
66
return nothing
77
end
88

9-
function cuda_fill!(dest::AbstractData, bc)
9+
function Base.fill!(dest::AbstractData, bc, ::ToCUDA)
1010
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest)
1111
us = DataLayouts.UniversalSize(dest)
1212
if Nv > 0 && Nh > 0
@@ -23,5 +23,3 @@ function cuda_fill!(dest::AbstractData, bc)
2323
end
2424
return dest
2525
end
26-
27-
Base.fill!(dest::AbstractData, val, ::ToCUDA) = cuda_fill!(dest, val)

src/DataLayouts/DataLayouts.jl

+7-9
Original file line numberDiff line numberDiff line change
@@ -1524,19 +1524,17 @@ array2data(array::AbstractArray{T}, data::AbstractData) where {T} =
15241524
)
15251525

15261526
"""
1527-
device_dispatch(data::AbstractData)
1527+
device_dispatch(array::AbstractArray)
15281528
15291529
Returns an `ToCPU` or a `ToCUDA` for CPU
15301530
and CUDA-backed arrays accordingly.
15311531
"""
1532-
device_dispatch(dest::AbstractData) = _device_dispatch(dest)
1533-
1534-
_device_dispatch(x::Array) = ToCPU()
1535-
_device_dispatch(x::SubArray) = _device_dispatch(parent(x))
1536-
_device_dispatch(x::Base.ReshapedArray) = _device_dispatch(parent(x))
1537-
_device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
1538-
_device_dispatch(x::SArray) = ToCPU()
1539-
_device_dispatch(x::MArray) = ToCPU()
1532+
device_dispatch(x::Array) = ToCPU()
1533+
device_dispatch(x::SubArray) = device_dispatch(parent(x))
1534+
device_dispatch(x::Base.ReshapedArray) = device_dispatch(parent(x))
1535+
device_dispatch(x::AbstractData) = device_dispatch(parent(x))
1536+
device_dispatch(x::SArray) = ToCPU()
1537+
device_dispatch(x::MArray) = ToCPU()
15401538

15411539
include("copyto.jl")
15421540
include("fused_copyto.jl")

src/DataLayouts/copyto.jl

+14-20
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
Base.copyto!(
66
dest::AbstractData,
7-
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
8-
) = Base.copyto!(dest, bc, device_dispatch(dest))
7+
@nospecialize(bc::Union{AbstractData, Base.Broadcast.Broadcasted}),
8+
) = Base.copyto!(dest, bc, device_dispatch(parent(dest)))
99

1010
# Specialize on non-Broadcasted objects
1111
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
@@ -15,8 +15,6 @@ end
1515

1616
# broadcasting scalar assignment
1717
# Performance optimization for the common identity scalar case: dest .= val
18-
# And this is valid for the CPU or GPU, since the broadcasted object
19-
# is a scalar type.
2018
function Base.copyto!(
2119
dest::AbstractData,
2220
bc::Base.Broadcast.Broadcasted{Style},
@@ -51,10 +49,9 @@ function Base.copyto!(
5149
::ToCPU,
5250
) where {S, Nij}
5351
(_, _, _, _, Nh) = size(dest)
54-
@inbounds for h in 1:Nh
55-
slab_dest = slab(dest, h)
56-
slab_bc = slab(bc, h)
57-
copyto!(slab_dest, slab_bc)
52+
@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij
53+
idx = CartesianIndex(i, j, 1, 1, h)
54+
dest[idx] = convert(S, bc[idx])
5855
end
5956
return dest
6057
end
@@ -65,10 +62,9 @@ function Base.copyto!(
6562
::ToCPU,
6663
) where {S, Ni}
6764
(_, _, _, _, Nh) = size(dest)
68-
@inbounds for h in 1:Nh
69-
slab_dest = slab(dest, h)
70-
slab_bc = slab(bc, h)
71-
copyto!(slab_dest, slab_bc)
65+
@inbounds for h in 1:Nh, i in 1:Ni
66+
idx = CartesianIndex(i, 1, 1, 1, h)
67+
dest[idx] = convert(S, bc[idx])
7268
end
7369
return dest
7470
end
@@ -131,10 +127,9 @@ function Base.copyto!(
131127
) where {S, Nv, Ni}
132128
# copy contiguous columns
133129
(_, _, _, _, Nh) = size(dest)
134-
@inbounds for h in 1:Nh, i in 1:Ni
135-
col_dest = column(dest, i, h)
136-
col_bc = column(bc, i, h)
137-
copyto!(col_dest, col_bc)
130+
@inbounds for h in 1:Nh, i in 1:Ni, v in 1:Nv
131+
idx = CartesianIndex(i, 1, 1, v, h)
132+
dest[idx] = convert(S, bc[idx])
138133
end
139134
return dest
140135
end
@@ -146,10 +141,9 @@ function Base.copyto!(
146141
) where {S, Nv, Nij}
147142
# copy contiguous columns
148143
(_, _, _, _, Nh) = size(dest)
149-
@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij
150-
col_dest = column(dest, i, j, h)
151-
col_bc = column(bc, i, j, h)
152-
copyto!(col_dest, col_bc)
144+
@inbounds for h in 1:Nh, j in 1:Nij, i in 1:Nij, v in 1:Nv
145+
idx = CartesianIndex(i, j, 1, v, h)
146+
dest[idx] = convert(S, bc[idx])
153147
end
154148
return dest
155149
end

src/DataLayouts/fill.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,4 @@ function Base.fill!(data::VIFH, val, ::ToCPU)
5858
end
5959

6060
Base.fill!(dest::AbstractData, val) =
61-
Base.fill!(dest, val, device_dispatch(dest))
61+
Base.fill!(dest, val, device_dispatch(parent(dest)))

src/DataLayouts/fused_copyto.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function Base.copyto!(
1818
end,
1919
)
2020
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
21-
fused_copyto!(fmb_inst, dest1, device_dispatch(dest1))
21+
fused_copyto!(fmb_inst, dest1, device_dispatch(parent(dest1)))
2222
end
2323

2424
function fused_copyto!(

0 commit comments

Comments
 (0)