|
| 1 | +#! format: off |
| 2 | +# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4) |
| 3 | +import Base.Broadcast: BroadcastStyle |
| 4 | +struct NonExtrudedBroadcasted{ |
| 5 | + Style <: Union{Nothing, BroadcastStyle}, |
| 6 | + Axes, |
| 7 | + F, |
| 8 | + Args <: Tuple, |
| 9 | +} <: Base.AbstractBroadcasted |
| 10 | + style::Style |
| 11 | + f::F |
| 12 | + args::Args |
| 13 | + axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `NonExtrudedBroadcasted`) |
| 14 | + |
| 15 | + NonExtrudedBroadcasted(style::Union{Nothing, BroadcastStyle}, f::Tuple, args::Tuple) = |
| 16 | + error() # disambiguation: tuple is not callable |
| 17 | + function NonExtrudedBroadcasted( |
| 18 | + style::Union{Nothing, BroadcastStyle}, |
| 19 | + f::F, |
| 20 | + args::Tuple, |
| 21 | + axes = nothing, |
| 22 | + ) where {F} |
| 23 | + # using Core.Typeof rather than F preserves inferrability when f is a type |
| 24 | + return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}( |
| 25 | + style, |
| 26 | + f, |
| 27 | + args, |
| 28 | + axes, |
| 29 | + ) |
| 30 | + end |
| 31 | + function NonExtrudedBroadcasted(f::F, args::Tuple, axes = nothing) where {F} |
| 32 | + NonExtrudedBroadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes) |
| 33 | + end |
| 34 | + function NonExtrudedBroadcasted{Style}(f::F, args, axes = nothing) where {Style, F} |
| 35 | + return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}( |
| 36 | + Style()::Style, |
| 37 | + f, |
| 38 | + args, |
| 39 | + axes, |
| 40 | + ) |
| 41 | + end |
| 42 | + function NonExtrudedBroadcasted{Style, Axes, F, Args}( |
| 43 | + f, |
| 44 | + args, |
| 45 | + axes, |
| 46 | + ) where {Style, Axes, F, Args} |
| 47 | + return new{Style, Axes, F, Args}(Style()::Style, f, args, axes) |
| 48 | + end |
| 49 | +end |
| 50 | + |
| 51 | +@inline to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted) = |
| 52 | + NonExtrudedBroadcasted(bc.style, bc.f, to_non_extruded_broadcasted(bc.args), bc.axes) |
| 53 | +@inline to_non_extruded_broadcasted(x) = x |
| 54 | +NonExtrudedBroadcasted(bc::Base.Broadcast.Broadcasted) = to_non_extruded_broadcasted(bc) |
| 55 | + |
| 56 | +@inline to_non_extruded_broadcasted(args::Tuple) = ( |
| 57 | + to_non_extruded_broadcasted(args[1]), |
| 58 | + to_non_extruded_broadcasted(Base.tail(args))..., |
| 59 | +) |
| 60 | +@inline to_non_extruded_broadcasted(args::Tuple{Any}) = |
| 61 | + (to_non_extruded_broadcasted(args[1]),) |
| 62 | +@inline to_non_extruded_broadcasted(args::Tuple{}) = () |
| 63 | + |
| 64 | +@inline _checkbounds(bc, _, I) = nothing # TODO: fix this case |
| 65 | +@inline _checkbounds(bc, ::Tuple, I) = Base.checkbounds(bc, I) |
| 66 | +@inline function Base.getindex( |
| 67 | + bc::NonExtrudedBroadcasted, |
| 68 | + I::Union{Integer, CartesianIndex}, |
| 69 | +) |
| 70 | + @boundscheck _checkbounds(bc, axes(bc), I) # is this really the only issue? |
| 71 | + @inbounds _broadcast_getindex(bc, I) |
| 72 | +end |
| 73 | + |
| 74 | +# --- here, we define our own bounds checks |
| 75 | +@inline function Base.checkbounds(bc::NonExtrudedBroadcasted, I::Integer) |
| 76 | + # Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) # from Base |
| 77 | + Base.checkbounds_indices(Bool, (Base.OneTo(n_dofs(bc)),), (I,)) || Base.throw_boundserror(bc, (I,)) |
| 78 | +end |
| 79 | + |
| 80 | +import StaticArrays |
| 81 | +to_tuple(t::Tuple) = t |
| 82 | +to_tuple(t::NTuple{N, <: Base.OneTo}) where {N} = map(x->x.stop, t) |
| 83 | +to_tuple(t::NTuple{N, <: StaticArrays.SOneTo}) where {N} = map(x->x.stop, t) |
| 84 | +n_dofs(bc::NonExtrudedBroadcasted) = prod(to_tuple(axes(bc))) |
| 85 | +# --- |
| 86 | + |
| 87 | +Base.@propagate_inbounds _broadcast_getindex( |
| 88 | + A::Union{Ref, AbstractArray{<:Any, 0}, Number}, |
| 89 | + I::Integer, |
| 90 | +) = A[] # Scalar-likes can just ignore all indices |
| 91 | +Base.@propagate_inbounds _broadcast_getindex( |
| 92 | + ::Ref{Type{T}}, |
| 93 | + I::Integer, |
| 94 | +) where {T} = T |
| 95 | +# Tuples are statically known to be singleton or vector-like |
| 96 | +Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I::Integer) = A[1] |
| 97 | +Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I::Integer) = A[I[1]] |
| 98 | +# Everything else falls back to dynamically dropping broadcasted indices based upon its axes |
| 99 | +# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)] |
| 100 | +Base.@propagate_inbounds _broadcast_getindex(A, I::Integer) = A[I] |
| 101 | +Base.@propagate_inbounds function _broadcast_getindex( |
| 102 | + bc::NonExtrudedBroadcasted{<:Any, <:Any, <:Any, <:Any}, |
| 103 | + I::Integer, |
| 104 | +) |
| 105 | + args = _getindex(bc.args, I) |
| 106 | + return _broadcast_getindex_evalf(bc.f, args...) |
| 107 | +end |
| 108 | +@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} = |
| 109 | + f(args...) # not propagate_inbounds |
| 110 | +Base.@propagate_inbounds _getindex(args::Tuple, I) = |
| 111 | + (_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...) |
| 112 | +Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) = |
| 113 | + (_broadcast_getindex(args[1], I),) |
| 114 | +Base.@propagate_inbounds _getindex(args::Tuple{}, I) = () |
| 115 | + |
| 116 | +@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes) |
| 117 | +_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes |
| 118 | +@inline _axes(bc::NonExtrudedBroadcasted, ::Nothing) = Base.Broadcast.combine_axes(bc.args...) |
| 119 | +_axes(bc::NonExtrudedBroadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}, ::Nothing) = () |
| 120 | +@inline Base.axes(bc::NonExtrudedBroadcasted{<:Any, <:NTuple{N}}, d::Integer) where {N} = |
| 121 | + d <= N ? axes(bc)[d] : OneTo(1) |
| 122 | +Base.IndexStyle(::Type{<:NonExtrudedBroadcasted{<:Any, <:Tuple{Any}}}) = IndexLinear() |
| 123 | +@inline _axes(::NonExtrudedBroadcasted, axes) = axes |
| 124 | +@inline Base.eltype(bc::NonExtrudedBroadcasted) = Base.Broadcast.combine_axes(bc.args...) |
| 125 | + |
| 126 | + |
| 127 | +# ============================================================ |
| 128 | + |
| 129 | +#! format: on |
| 130 | +# Datalayouts |
| 131 | +@propagate_inbounds function linear_getindex( |
| 132 | + data::AbstractData{S}, |
| 133 | + I::Integer, |
| 134 | +) where {S} |
| 135 | + s_array = farray_size(data) |
| 136 | + ss = StaticSize(s_array, field_dim(data)) |
| 137 | + @inbounds get_struct_linear(parent(data), S, Val(field_dim(data)), I, ss) |
| 138 | +end |
| 139 | +@propagate_inbounds function linear_setindex!( |
| 140 | + data::AbstractData{S}, |
| 141 | + val, |
| 142 | + I::Integer, |
| 143 | +) where {S} |
| 144 | + s_array = farray_size(data) |
| 145 | + ss = StaticSize(s_array, field_dim(data)) |
| 146 | + @inbounds set_struct_linear!( |
| 147 | + parent(data), |
| 148 | + convert(S, val), |
| 149 | + Val(field_dim(data)), |
| 150 | + I, |
| 151 | + ss, |
| 152 | + ) |
| 153 | +end |
| 154 | + |
| 155 | +for DL in (:IJKFVH, :IJFH, :IFH, :IJF, :IF, :VF, :VIJFH, :VIFH) # Skip DataF, since we want that to MethodError. |
| 156 | + @eval @propagate_inbounds Base.getindex(data::$(DL), I::Integer) = |
| 157 | + linear_getindex(data, I) |
| 158 | + @eval @propagate_inbounds Base.setindex!(data::$(DL), val, I::Integer) = |
| 159 | + linear_setindex!(data, val, I) |
| 160 | +end |
0 commit comments