From 8e2a324c8155a1eb9d75e73ed5051a39e83eb873 Mon Sep 17 00:00:00 2001 From: Thomas Christensen Date: Fri, 18 Oct 2024 12:43:58 +0200 Subject: [PATCH] add `init` kwarg to `sum` and `prod` as well (fix #1119) --- src/mapreduce.jl | 12 ++++++------ test/mapreduce.jl | 8 ++++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 86f99579..2df6cd04 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -284,13 +284,13 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) = # TODO: change to use Base.reduce_empty/Base.reduce_first @inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true) -@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims) -@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) -@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) # avoid ambiguity +@inline sum(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(+, a, dims, init) +@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) +@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, +, dims, init, Size(a), a) # avoid ambiguity -@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims) -@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a) -@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a) +@inline prod(a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _reduce(*, a, dims, init) +@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a) +@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:, init=_InitialValue()) where {T} = _mapreduce(f, *, dims, init, Size(a), a) @inline count(a::StaticArray{<:Tuple,Bool}; dims=:, init=0) = _reduce(+, a, dims, init) @inline count(f, a::StaticArray; dims=:, init=0) = _mapreduce(x->f(x)::Bool, +, dims, init, Size(a), a) diff --git a/test/mapreduce.jl b/test/mapreduce.jl index d8ced2d9..3dee49f2 100644 --- a/test/mapreduce.jl +++ b/test/mapreduce.jl @@ -130,18 +130,22 @@ using Statistics: mean @test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2)) @test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2)) @test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2)) + @test sum(sa, init=2) == sum(a, init=2) == sum(sa) + 2 + @test sum(sb, init=2) == sum(b, init=2) == sum(sb) + 2 @test prod(sa) === prod(a) @test prod(abs2, sa) === prod(abs2, a) @test prod(sa, dims=Val(2)) === RSArray2(prod(a, dims=2)) @test prod(abs2, sa, dims=Val(2)) === RSArray2(prod(abs2, a, dims=2)) + @test prod(sa, init=2) == prod(a, init=2) == 2*prod(sa) + @test prod(sb, init=2) == prod(b, init=2) == 2*prod(sb) @test count(sb) === count(b) - @test count(sb, init=3) == count(b, init=3) == count(sb) + 3 @test count(x->x>0, sa) === count(x->x>0, a) - @test count(x->x>0, sa, init=-2) == count(x->x>0, a, init=-2) == count(x->x>0, sa) - 2 @test count(sb, dims=Val(2)) === RSArray2(reshape([count(b[i,:,k]) for i = 1:I, k = 1:K], (I,1,K))) @test count(x->x>0, sa, dims=Val(2)) === RSArray2(reshape([count(x->x>0, a[i,:,k]) for i = 1:I, k = 1:K], (I,1,K))) + @test count(sb, init=3) == count(b, init=3) == count(sb) + 3 + @test count(x->x>0, sa, init=-2) == count(x->x>0, a, init=-2) == count(x->x>0, sa) - 2 @test all(sb) === all(b) @test all(x->x>0, sa) === all(x->x>0, a)