-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathcopyto.jl
163 lines (150 loc) · 3.88 KB
/
copyto.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#####
##### Dispatching and edge cases
#####
# Top-level `copyto!` for DataLayouts.
#
# Resolves the target device of `dest` once, then either:
#   * on the CPU, when every layout involved in `bc` is uniform
#     (`has_uniform_datalayouts`) and `dest` is not a `DataF`, runs a flat
#     linear-index loop over all `get_N(UniversalSize(dest))` points; or
#   * otherwise forwards to the 3-argument `copyto!(dest, bc, device)` methods,
#     which specialize on the concrete layout and device.
#
# Returns `dest`.
function Base.copyto!(
    dest::AbstractData{S},
    bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) where {S}
    dev = device_dispatch(dest)
    if dev isa ToCPU && has_uniform_datalayouts(bc) && !(dest isa DataF)
        # Specialize on the linear-indexing case. `to_non_extruded_broadcasted`
        # presumably rewrites `bc` so it can be indexed with a single linear
        # index like `dest` (TODO confirm against its definition).
        bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
        @inbounds @simd for I in 1:get_N(UniversalSize(dest))
            dest[I] = convert(S, bc′[I])
        end
    else
        # Reuse the device resolved above instead of dispatching a second time.
        Base.copyto!(dest, bc, dev)
    end
    return dest
end
# Non-broadcast fast path: when source and destination are the exact same
# layout type `D`, skip the broadcast machinery and copy the backing parent
# arrays directly.
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
    dest_parent = parent(dest)
    copyto!(dest_parent, parent(src))
    return dest
end
# Broadcasting scalar assignment.
#
# Performance optimization for the common identity scalar case `dest .= val`:
# a broadcast whose style is zero-dimensional (or `Style{Tuple}`) is
# re-instantiated with empty axes, evaluated once with `[]`, and the
# resulting value is written everywhere with `fill!`. This is valid for the
# CPU or GPU, since the broadcasted object is a scalar type.
#
# Returns `dest`, matching `Base.copyto!`'s contract and the other methods in
# this file.
function Base.copyto!(
    dest::AbstractData,
    bc::Base.Broadcast.Broadcasted{Style},
    ::AbstractDispatchToDevice,
) where {
    Style <:
    Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
}
    # Rebuild with empty axes so the broadcast can be evaluated with `[]`.
    # Use a new local rather than rebinding the `bc` argument.
    bc_scalar = Base.Broadcast.instantiate(
        Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
    )
    @inbounds val = bc_scalar[]
    fill!(dest, val)
    # Return `dest` explicitly instead of relying on `fill!`'s return value.
    return dest
end
#####
##### DataLayouts
#####
# CPU copy into a `DataF`, which is accessed as a single value via `[]`:
# evaluate the broadcast once, convert to the element type `S` of `dest`,
# and store it.
function Base.copyto!(
    dest::DataF{S},
    bc::BroadcastedUnionDataF{S, A},
    ::ToCPU,
) where {S, A}
    @inbounds begin
        value = convert(S, bc[])
        dest[] = value
    end
    return dest
end
# CPU copy for an `IJFH` layout: iterate over the `Nh` elements along H and
# copy each pair of corresponding slabs with `copyto!`.
function Base.copyto!(
    dest::IJFH{S, Nij},
    bc::BroadcastedUnionIJFH{S, Nij, Nh},
    ::ToCPU,
) where {S, Nij, Nh}
    @inbounds for helem in 1:Nh
        copyto!(slab(dest, helem), slab(bc, helem))
    end
    return dest
end
# CPU copy for an `IFH` layout: iterate over the `Nh` elements along H and
# copy each pair of corresponding slabs with `copyto!`.
function Base.copyto!(
    dest::IFH{S, Ni},
    bc::BroadcastedUnionIFH{S, Ni, Nh},
    ::ToCPU,
) where {S, Ni, Nh}
    @inbounds for helem in 1:Nh
        copyto!(slab(dest, helem), slab(bc, helem))
    end
    return dest
end
# Inline inner slab(::DataSlab2D) copy.
# Element-wise CPU copy for an `IJF` slab: visit every (i, j) node, with `i`
# varying fastest, and store `bc` converted to the element type `S`.
function Base.copyto!(
    dest::IJF{S, Nij},
    bc::BroadcastedUnionIJF{S, Nij, A},
    ::ToCPU,
) where {S, Nij, A}
    @inbounds for jj in 1:Nij
        for ii in 1:Nij
            node = CartesianIndex(ii, jj, 1, 1, 1)
            dest[node] = convert(S, bc[node])
        end
    end
    return dest
end
# Element-wise CPU copy for an `IF` slab from a `BroadcastedUnionIF` source:
# visit each of the `Ni` points along I and store `bc` converted to `S`.
function Base.copyto!(
    dest::IF{S, Ni},
    bc::BroadcastedUnionIF{S, Ni, A},
    ::ToCPU,
) where {S, Ni, A}
    @inbounds for point in 1:Ni
        ci = CartesianIndex(point, 1, 1, 1, 1)
        dest[ci] = convert(S, bc[ci])
    end
    return dest
end
# Inline inner slab(::DataSlab1D) copy.
# Element-wise CPU copy into an `IF` slab from a `Broadcasted` whose style is
# `IFStyle{Ni, A}`: visit each of the `Ni` points along I and store `bc`
# converted to `S`.
# NOTE(review): the body matches the `BroadcastedUnionIF` method; this
# narrower signature presumably exists for dispatch specificity. TODO confirm
# whether it is still needed.
function Base.copyto!(
    dest::IF{S, Ni},
    bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
    ::ToCPU,
) where {S, Ni, A}
    @inbounds for point in 1:Ni
        ci = CartesianIndex(point, 1, 1, 1, 1)
        dest[ci] = convert(S, bc[ci])
    end
    return dest
end
# Inline inner column(::DataColumn) copy.
# Element-wise CPU copy for a `VF` column: visit each of the `Nv` levels along
# V (4th index position) and store `bc` converted to `S`.
function Base.copyto!(
    dest::VF{S, Nv},
    bc::BroadcastedUnionVF{S, Nv, A},
    ::ToCPU,
) where {S, Nv, A}
    @inbounds for level in 1:Nv
        ci = CartesianIndex(1, 1, 1, level, 1)
        dest[ci] = convert(S, bc[ci])
    end
    return dest
end
# CPU copy for a `VIFH` layout, done column by column: for every `h` in
# 1:Nh and `i` in 1:Ni (with `i` varying fastest), copy the contiguous
# column at (i, h) with `copyto!`.
function Base.copyto!(
    dest::VIFH{S, Nv, Ni, Nh},
    bc::BroadcastedUnionVIFH{S, Nv, Ni, Nh},
    ::ToCPU,
) where {S, Nv, Ni, Nh}
    @inbounds for helem in 1:Nh
        for ii in 1:Ni
            copyto!(column(dest, ii, helem), column(bc, ii, helem))
        end
    end
    return dest
end
# CPU copy for a `VIJFH` layout, done column by column: for every `h` in
# 1:Nh and every (i, j) node (with `i` varying fastest), copy the contiguous
# column at (i, j, h) with `copyto!`.
function Base.copyto!(
    dest::VIJFH{S, Nv, Nij, Nh},
    bc::BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
    ::ToCPU,
) where {S, Nv, Nij, Nh}
    @inbounds for helem in 1:Nh
        for jj in 1:Nij
            for ii in 1:Nij
                copyto!(column(dest, ii, jj, helem), column(bc, ii, jj, helem))
            end
        end
    end
    return dest
end