Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend macros with rand to support custom samplers #1210

Merged
merged 10 commits into from
Jan 3, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.6.5"
version = "1.7.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
43 changes: 37 additions & 6 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,43 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
if length(ex.args) == 1
f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)")
return :($f($SA{$Tuple{},$Float64}))
end
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
else
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
elseif f !== :rand || length(ex.args) == 2
return quote
if isa($(esc(ex.args[2])), DataType)
mateuszbaran marked this conversation as resolved.
Show resolved Hide resolved
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
else
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
end
end
else
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
elseif isa($(esc(ex.args[2])), Integer)
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
elseif isa($(esc(ex.args[2])), Random.AbstractRNG)
# for calls like rand(rng::AbstractRNG, sampler, dims::Integer...)
StaticArrays._rand(
$(esc(ex.args[2])),
$(esc(ex.args[3])),
Size($(escall(ex.args[4:end])...)),
$SA{
Tuple{$(escall(ex.args[4:end])...)},
Random.gentype($(esc(ex.args[3]))),
},
)
else
# for calls like rand(sampler, dims::Integer...)
StaticArrays._rand(
Random.GLOBAL_RNG,
$(esc(ex.args[2])),
Size($(escall(ex.args[3:end])...)),
$SA{
Tuple{$(escall(ex.args[3:end])...)},
Random.gentype($(esc(ex.args[2]))),
},
)
end
end
end
elseif f === :fill
Expand Down
24 changes: 23 additions & 1 deletion src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,29 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM
if length(ex.args) == 3
return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base
elseif length(ex.args) == 4
return :($f($SM{$(escall(ex.args[[3,4,2]])...)}))
if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0 && ex.args[4] isa Int && ex.args[4] ≥ 0
# for calls like rand(sampler, n, m) or rand(type, n, m)
return quote
StaticArrays._rand(
Random.GLOBAL_RNG,
$(esc(ex.args[2])),
Size($(esc(ex.args[3])), $(esc(ex.args[4]))),
$SM{$(esc(ex.args[3])), $(esc(ex.args[4])), Random.gentype($(esc(ex.args[2])))},
)
end
else
return :($f($SM{$(escall(ex.args[[3,4,2]])...)}))
end
elseif length(ex.args) == 5 && f === :rand && ex.args[4] isa Int && ex.args[4] ≥ 0 && ex.args[5] isa Int && ex.args[5] ≥ 0
# for calls like rand(rng::AbstractRNG, sampler, n, m) or rand(rng::AbstractRNG, type, n, m)
return quote
StaticArrays._rand(
$(esc(ex.args[2])),
$(esc(ex.args[3])),
Size($(esc(ex.args[4])), $(esc(ex.args[5]))),
$SM{$(esc(ex.args[4])), $(esc(ex.args[5])), Random.gentype($(esc(ex.args[3])))},
)
end
else
error("@$SM expected a 2-dimensional array expression")
end
Expand Down
24 changes: 23 additions & 1 deletion src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,29 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
if length(ex.args) == 2
return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base
elseif length(ex.args) == 3
return :($f($SV{$(escall(ex.args[3:-1:2])...)}))
if f === :rand && ex.args[3] isa Int && ex.args[3] ≥ 0
# for calls like rand(sampler, n) or rand(type, n)
return quote
StaticArrays._rand(
Random.GLOBAL_RNG,
$(esc(ex.args[2])),
Size($(esc(ex.args[3]))),
$SV{$(esc(ex.args[3])), Random.gentype($(esc(ex.args[2])))},
)
end
else
return :($f($SV{$(escall(ex.args[3:-1:2])...)}))
end
elseif length(ex.args) == 4 && f === :rand && ex.args[4] isa Int && ex.args[4] ≥ 0
# for calls like rand(rng::AbstractRNG, sampler, n) or rand(rng::AbstractRNG, type, n)
return quote
StaticArrays._rand(
$(esc(ex.args[2])),
$(esc(ex.args[3])),
Size($(esc(ex.args[4]))),
$SV{$(esc(ex.args[4])), Random.gentype($(esc(ex.args[3])))},
)
end
else
error("@$SV expected a 1-dimensional array expression")
end
Expand Down
4 changes: 2 additions & 2 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ end

@inline rand(rng::AbstractRNG, range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(rng, range, Size(SA), SA)
@inline rand(range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(Random.GLOBAL_RNG, range, Size(SA), SA)
@generated function _rand(rng::AbstractRNG, range::AbstractArray, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, range)) for i = 1:prod(s)]
@generated function _rand(rng::AbstractRNG, X, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, X)) for i = 1:prod(s)]
return quote
@_inline_meta
$SA(tuple($(v...)))
Expand Down
42 changes: 42 additions & 0 deletions test/arraymath.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
using StaticArrays, Test
import StaticArrays.arithmetic_closure

struct TestDie
nsides::Int
end
Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{TestDie}) = rand(rng, 1:d[].nsides)
Base.eltype(::Type{TestDie}) = Int

@testset "Array math" begin
@testset "zeros() and ones()" begin
@test @inferred(zeros(SVector{3,Float64})) === @SVector [0.0, 0.0, 0.0]
Expand Down Expand Up @@ -179,6 +185,42 @@ import StaticArrays.arithmetic_closure
@test v4 isa SA{0, Float32}
@test all(0 .< v4 .< 1)
end
rng = MersenneTwister(123)
@test (@SVector rand(3)) isa SVector{3,Float64}
@test (@SMatrix rand(3, 4)) isa SMatrix{3,4,Float64}
@test (@SArray rand(3, 4, 5)) isa SArray{Tuple{3,4,5},Float64}

@test (@MVector rand(3)) isa MVector{3,Float64}
@test (@MMatrix rand(3, 4)) isa MMatrix{3,4,Float64}
@test (@MArray rand(3, 4, 5)) isa MArray{Tuple{3,4,5},Float64}

@test (@SVector rand(TestDie(6), 3)) isa SVector{3,Int}
@test (@SVector rand(rng, TestDie(6), 3)) isa SVector{3,Int}
@test (@SVector rand(TestDie(6), 0)) isa SVector{0,Int}
@test (@SVector rand(rng, TestDie(6), 0)) isa SVector{0,Int}
@test (@MVector rand(TestDie(6), 3)) isa MVector{3,Int}
@test (@MVector rand(rng, TestDie(6), 3)) isa MVector{3,Int}

@test (@SMatrix rand(TestDie(6), 3, 4)) isa SMatrix{3,4,Int}
@test (@SMatrix rand(rng, TestDie(6), 3, 4)) isa SMatrix{3,4,Int}
@test (@SMatrix rand(TestDie(6), 0, 4)) isa SMatrix{0,4,Int}
@test (@SMatrix rand(rng, TestDie(6), 0, 4)) isa SMatrix{0,4,Int}
@test (@MMatrix rand(TestDie(6), 3, 4)) isa MMatrix{3,4,Int}
@test (@MMatrix rand(rng, TestDie(6), 3, 4)) isa MMatrix{3,4,Int}

@test (@SArray rand(TestDie(6), 3, 4, 5)) isa SArray{Tuple{3,4,5},Int}
@test (@SArray rand(rng, TestDie(6), 3, 4, 5)) isa SArray{Tuple{3,4,5},Int}
@test (@SArray rand(TestDie(6), 0, 4, 5)) isa SArray{Tuple{0,4,5},Int}
@test (@SArray rand(rng, TestDie(6), 0, 4, 5)) isa SArray{Tuple{0,4,5},Int}
@test (@MArray rand(TestDie(6), 3, 4, 5)) isa MArray{Tuple{3,4,5},Int}

# test if rng generator is actually respected
@test (@SVector rand(MersenneTwister(123), TestDie(6), 3)) ===
(@SVector rand(MersenneTwister(123), TestDie(6), 3))
@test (@SMatrix rand(MersenneTwister(123), TestDie(6), 3, 4)) ===
(@SMatrix rand(MersenneTwister(123), TestDie(6), 3, 4))
@test (@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5)) ===
(@SArray rand(MersenneTwister(123), TestDie(6), 3, 4, 5))
end

@testset "rand!()" begin
Expand Down
Loading