Skip to content

Commit

Permalink
Extend macros with rand to support custom samplers (JuliaArrays#1210)
Browse files Browse the repository at this point in the history
* Extend macros with rand to support custom samplers

* Fix tests

* Update Project.toml

Co-authored-by: Yuto Horikawa <hyrodium@gmail.com>

* Update test/arraymath.jl

Co-authored-by: Yuto Horikawa <hyrodium@gmail.com>

* Update src/SVector.jl

Co-authored-by: Yuto Horikawa <hyrodium@gmail.com>

* Update src/SMatrix.jl

Co-authored-by: Yuto Horikawa <hyrodium@gmail.com>

* Update `SArray` macro

* fix reported issue; support rng in SArray and SMatrix

* Code suggestions for JuliaArrays#1210 (JuliaArrays#1213)

* move `ex.args[2] isa Integer`

* split `if` block

* simplify :zeros and :ones

* refactor :rand

* refactor :randn and :randexp

* update comments

* add _isnonnegvec

* update with `_isnonnegvec`

* add `_isnonnegvec(args, n)` method to check the size of `args`

* fix `@SArray` for `@SArray rand(rng,T,dim)` etc.

* update comments

* update `@SVector` macro

* update `@SMatrix`

* update `@SVector`

* update `@SArray`

* introduce `fargs` variable

* avoid `_isnonnegvec` in `static_matrix_gen`

* avoid `_isnonnegvec` in `static_vector_gen`

* remove unnecessary `_isnonnegvec`

* add `_rng()` function

* update tests on `@SVector` macro

* update tests on `@MVector` macro

* organize test/MMatrix.jl and test/SMatrix.jl

* organize test/MMatrix.jl and test/SMatrix.jl

* update with broken tests

* organize test/MMatrix.jl and test/SMatrix.jl for `rand*` functions

* fix around `broken` key for `@test` macro

* fix zero-length tests

* update `test/SArray.jl` to match `test/MArray.jl`

* update tests for `@SArray ones` etc.

* add supports for `@SArray ones(3-1,2)` etc.

* move block for `fill`

* update macro `@SArray rand(rng,2,3)` to use ordinary dispatches

* update around `@SArray randn` etc.

* remove unnecessary dollars

* simplify `@SArray fill`

* add `@testset "expand_error"`

* update tests for `@SArray rand(...)` etc.

* fix bug in `rand*_with_Val`

* cleanup tests

* update macro `@SMatrix rand(rng,2,3)` to use ordinary dispatches

* update macro `@SVector rand(rng,3)` to use ordinary dispatches

* move block for `fill`

* simplify `_randexp_with_Val`

---------

Co-authored-by: Yuto Horikawa <hyrodium@gmail.com>
  • Loading branch information
2 people authored and avik-pal committed Jan 12, 2024
1 parent dd6c386 commit bc00c6a
Show file tree
Hide file tree
Showing 13 changed files with 826 additions and 179 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.8.2"
version = "1.9.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
113 changes: 94 additions & 19 deletions src/SArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,65 @@ function parse_cat_ast(ex::Expr)
cat_any(Val(maxdim), Val(catdim), nargs)
end

#=
For example,
* `@SArray rand(2, 3, 4)`
* `@SArray rand(rng, 3, 4)`
will be expanded to the following.
* `_rand_with_Val(SArray, 2, 3, _int2val(2), _int2val(3), Val((4,)))`
* `_rand_with_Val(SArray, 2, 3, _int2val(rng), _int2val(3), Val((4,)))`
The function `_int2val` is required to avoid the following case.
* `_rand_with_Val(SArray, 2, 3, Val(2), Val(3), Val((4,)))`
* `_rand_with_Val(SArray, 2, 3, Val(rng), Val(3), Val((4,)))`
Mutable object such as `rng` cannot be type parameter, and `Val(rng)` throws an error.
=#
_int2val(x::Int) = Val(x)
_int2val(::Any) = nothing
# @SArray zeros(...)
_zeros_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = zeros(SA{Tuple{n1, ns...}})
_zeros_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = zeros(SA{Tuple{ns...}, T})
# @SArray ones(...)
_ones_with_Val(::Type{SA}, ::Int, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = ones(SA{Tuple{n1, ns...}})
_ones_with_Val(::Type{SA}, T::DataType, ::Val, ::Val{ns}) where {SA, ns} = ones(SA{Tuple{ns...}, T})
# @SArray rand(...)
_rand_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = rand(SA{Tuple{n1,n2,ns...}})
_rand_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, T, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_rand_with_Val(::Type{SA}, sampler, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, ns...), SA{Tuple{n1, ns...}, Random.gentype(sampler)})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _rand(rng, Float64, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, T, Size(ns...), SA{Tuple{ns...}, T})
_rand_with_Val(::Type{SA}, rng::AbstractRNG, sampler, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _rand(rng, sampler, Size(ns...), SA{Tuple{ns...}, Random.gentype(sampler)})
# @SArray randn(...)
_randn_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randn(SA{Tuple{n1,n2,ns...}})
_randn_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_randn_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randn(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_randn_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randn(rng, Size(ns...), SA{Tuple{ns...}, T})
# @SArray randexp(...)
_randexp_with_Val(::Type{SA}, ::Int, ::Int, ::Val{n1}, ::Val{n2}, ::Val{ns}) where {SA, n1, n2, ns} = randexp(SA{Tuple{n1,n2,ns...}})
_randexp_with_Val(::Type{SA}, T::DataType, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(Random.GLOBAL_RNG, Size(n1, ns...), SA{Tuple{n1, ns...}, T})
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, ::Int, ::Nothing, ::Val{n1}, ::Val{ns}) where {SA, n1, ns} = _randexp(rng, Size(n1, ns...), SA{Tuple{n1, ns...}, Float64})
_randexp_with_Val(::Type{SA}, rng::AbstractRNG, T::DataType, ::Nothing, ::Nothing, ::Val{ns}) where {SA, ns} = _randexp(rng, Size(ns...), SA{Tuple{ns...}, T})

escall(args) = Iterators.map(esc, args)
function _isnonnegvec(args)
length(args) == 0 && return false
all(isa.(args, Integer)) && return all(args .≥ 0)
return false
end
function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
if !isa(ex, Expr)
error("Bad input for @$SA")
end
head = ex.head
if head === :vect # vector
return :($SA{$Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
return :($SA{Tuple{$(length(ex.args))}}($tuple($(escall(ex.args)...))))
elseif head === :ref # typed, vector
return :($SA{$Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
return :($SA{Tuple{$(length(ex.args)-1)},$(esc(ex.args[1]))}($tuple($(escall(ex.args[2:end])...))))
elseif head === :typed_vcat || head === :typed_hcat || head === :typed_ncat # typed, cat
args = parse_cat_ast(ex)
return :($SA{$Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
return :($SA{Tuple{$(size(args)...)},$(esc(ex.args[1]))}($tuple($(escall(args)...))))
elseif head === :vcat || head === :hcat || head === :ncat # untyped, cat
args = parse_cat_ast(ex)
return :($SA{$Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
return :($SA{Tuple{$(size(args)...)}}($tuple($(escall(args)...))))
elseif head === :comprehension
if length(ex.args) != 1
error("Expected generator in comprehension, e.g. [f(i,j) for i = 1:3, j = 1:3]")
Expand All @@ -173,7 +216,7 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
return quote
let
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
$SA{$Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
$SA{Tuple{$(size(exprs)...)}}($tuple($(exprs...)))
end
end
elseif head === :typed_comprehension
Expand All @@ -192,26 +235,58 @@ function static_array_gen(::Type{SA}, @nospecialize(ex), mod::Module) where {SA}
return quote
let
f($(escall(rng_args)...)) = $(esc(ex.args[1]))
$SA{$Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
$SA{Tuple{$(size(exprs)...)},$T}($tuple($(exprs...)))
end
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 1
f === :zeros || f === :ones || error("@$SA got bad expression: $(ex)")
return :($f($SA{$Tuple{},$Float64}))
end
return quote
if isa($(esc(ex.args[2])), DataType)
$f($SA{$Tuple{$(escall(ex.args[3:end])...)},$(esc(ex.args[2]))})
else
$f($SA{$Tuple{$(escall(ex.args[2:end])...)}})
end
fargs = ex.args[2:end]
if f === :zeros || f === :ones
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 0
# for calls like `zeros()`
return :($f($SA{Tuple{},$Float64}))
elseif _isnonnegvec(fargs)
# for calls like `zeros(dims...)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
else
# for calls like `zeros(type)`
# for calls like `zeros(type, dims...)`
return :($_f_with_Val($SA, $(esc(fargs[1])), Val($(esc(fargs[1]))), Val(tuple($(escall(fargs[2:end])...)))))
end
elseif f === :fill
length(ex.args) == 1 && error("@$SA got bad expression: $(ex)")
return :($f($(esc(ex.args[2])), $SA{$Tuple{$(escall(ex.args[3:end])...)}}))
# for calls like `fill(value, dims...)`
return :($f($(esc(fargs[1])), $SA{Tuple{$(escall(fargs[2:end])...)}}))
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 0
# No support for `@SArray rand()`
error("@$SA got bad expression: $(ex)")
elseif _isnonnegvec(fargs)
# for calls like `rand(dims...)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
elseif length(fargs) 2
# for calls like `rand(dim1, dim2, dims...)`
# for calls like `rand(type, dim1, dims...)`
# for calls like `rand(sampler, dim1, dims...)`
# for calls like `rand(rng, dim1, dims...)`
# for calls like `rand(rng, type, dims...)`
# for calls like `rand(rng, sampler, dims...)`
# for calls like `randn(dim1, dim2, dims...)`
# for calls like `randn(type, dim1, dims...)`
# for calls like `randn(rng, dim1, dims...)`
# for calls like `randn(rng, type, dims...)`
# for calls like `randexp(dim1, dim2, dims...)`
# for calls like `randexp(type, dim1, dims...)`
# for calls like `randexp(rng, dim1, dims...)`
# for calls like `randexp(rng, type, dims...)`
return :($_f_with_Val($SA, $(esc(fargs[1])), $(esc(fargs[2])), _int2val($(esc(fargs[1]))), _int2val($(esc(fargs[2]))), Val(tuple($(escall(fargs[3:end])...)))))
elseif length(fargs) == 1
# for calls like `rand(dim)`
return :($f($SA{Tuple{$(escall(fargs)...)}}))
else
error("@$SA got bad expression: $(ex)")
end
else
error("@$SA only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
Expand Down
64 changes: 54 additions & 10 deletions src/SMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ function check_matrix_size(x::Tuple, T = :S)
x1, x2
end

# @SMatrix rand(...)
_rand_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2})
_rand_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, T, Size(n1, n2), SM{n1, n2, T})
_rand_with_Val(::Type{SM}, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(Random.GLOBAL_RNG, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
_rand_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = rand(rng, SM{n1, n2, T})
_rand_with_Val(::Type{SM}, rng::AbstractRNG, sampler, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _rand(rng, sampler, Size(n1, n2), SM{n1, n2, Random.gentype(sampler)})
# @SMatrix randn(...)
_randn_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2})
_randn_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randn(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
_randn_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randn(rng, SM{n1, n2, T})
# @SMatrix randexp(...)
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2})
_randexp_with_Val(::Type{SM}, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = _randexp(Random.GLOBAL_RNG, Size(n1, n2), SM{n1, n2, T})
_randexp_with_Val(::Type{SM}, rng::AbstractRNG, T::DataType, ::Val{n1}, ::Val{n2}) where {SM, n1, n2} = randexp(rng, SM{n1, n2, T})

function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM}
if !isa(ex, Expr)
error("Bad input for @$SM")
Expand Down Expand Up @@ -69,22 +84,51 @@ function static_matrix_gen(::Type{SM}, @nospecialize(ex), mod::Module) where {SM
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 3
return :($f($SM{$(escall(ex.args[2:3])...), Float64})) # default to Float64 like Base
elseif length(ex.args) == 4
return :($f($SM{$(escall(ex.args[[3,4,2]])...)}))
fargs = ex.args[2:end]
if f === :zeros || f === :ones
if length(fargs) == 2
# for calls like `zeros(dim1, dim2)`
return :($f($SM{$(escall(fargs)...)}))
elseif length(fargs[2:end]) == 2
# for calls like `zeros(type, dim1, dim2)`
return :($f($SM{$(escall(fargs[2:end])...), $(esc(fargs[1]))}))
else
error("@$SM expected a 2-dimensional array expression")
error("@$SM got bad expression: $(ex)")
end
elseif ex.args[1] === :fill
if length(ex.args) == 4
return :($f($(esc(ex.args[2])), $SM{$(escall(ex.args[3:4])...)}))
elseif f === :fill
# for calls like `fill(value, dim1, dim2)`
if length(fargs[2:end]) == 2
return :($f($(esc(fargs[1])), $SM{$(escall(fargs[2:end])...)}))
else
error("@$SM expected a 2-dimensional array expression")
end
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 2
# for calls like `rand(dim1, dim2)`
# for calls like `randn(dim1, dim2)`
# for calls like `randexp(dim1, dim2)`
return :($f($SM{$(escall(fargs)...)}))
elseif length(fargs) == 3
# for calls like `rand(rng, dim1, dim2)`
# for calls like `rand(type, dim1, dim2)`
# for calls like `rand(sampler, dim1, dim2)`
# for calls like `randn(rng, dim1, dim2)`
# for calls like `randn(type, dim1, dim2)`
# for calls like `randexp(rng, dim1, dim2)`
# for calls like `randexp(type, dim1, dim2)`
return :($_f_with_Val($SM, $(esc(fargs[1])), Val($(esc(fargs[2]))), Val($(esc(fargs[3])))))
elseif length(fargs) == 4
# for calls like `rand(rng, type, dim1, dim2)`
# for calls like `rand(rng, sampler, dim1, dim2)`
# for calls like `randn(rng, type, dim1, dim2)`
# for calls like `randexp(rng, type, dim1, dim2)`
return :($_f_with_Val($SM, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3]))), Val($(esc(fargs[4])))))
else
error("@$SM got bad expression: $(ex)")
end
else
error("@$SM only supports the zeros(), ones(), rand(), randn(), and randexp() functions.")
error("@$SM only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
else
error("Bad input for @$SM")
Expand Down
64 changes: 54 additions & 10 deletions src/SVector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ function check_vector_length(x::Tuple, T = :S)
length(x) >= 1 ? x[1] : 1
end

# @SVector rand(...)
_rand_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = rand(rng, SV{n})
_rand_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, T, Size(n), SV{n, T})
_rand_with_Val(::Type{SV}, sampler, ::Val{n}) where {SV, n} = _rand(Random.GLOBAL_RNG, sampler, Size(n), SV{n, Random.gentype(sampler)})
_rand_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = rand(rng, SV{n, T})
_rand_with_Val(::Type{SV}, rng::AbstractRNG, sampler, ::Val{n}) where {SV, n} = _rand(rng, sampler, Size(n), SV{n, Random.gentype(sampler)})
# @SVector randn(...)
_randn_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randn(rng, SV{n})
_randn_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randn(Random.GLOBAL_RNG, Size(n), SV{n, T})
_randn_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randn(rng, SV{n, T})
# @SVector randexp(...)
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, ::Val{n}) where {SV, n} = randexp(rng, SV{n})
_randexp_with_Val(::Type{SV}, T::DataType, ::Val{n}) where {SV, n} = _randexp(Random.GLOBAL_RNG, Size(n), SV{n, T})
_randexp_with_Val(::Type{SV}, rng::AbstractRNG, T::DataType, ::Val{n}) where {SV, n} = randexp(rng, SV{n, T})

function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV}
if !isa(ex, Expr)
error("Bad input for @$SV")
Expand Down Expand Up @@ -74,22 +89,51 @@ function static_vector_gen(::Type{SV}, @nospecialize(ex), mod::Module) where {SV
end
elseif head === :call
f = ex.args[1]
if f === :zeros || f === :ones || f === :rand || f === :randn || f === :randexp
if length(ex.args) == 2
return :($f($SV{$(esc(ex.args[2])), Float64})) # default to Float64 like Base
elseif length(ex.args) == 3
return :($f($SV{$(escall(ex.args[3:-1:2])...)}))
fargs = ex.args[2:end]
if f === :zeros || f === :ones
if length(fargs) == 1
# for calls like `zeros(dim)`
return :($f($SV{$(esc(fargs[1]))}))
elseif length(fargs) == 2
# for calls like `zeros(type, dim)`
return :($f($SV{$(esc(fargs[2])), $(esc(fargs[1]))}))
else
error("@$SV expected a 1-dimensional array expression")
error("@$SV got bad expression: $(ex)")
end
elseif ex.args[1] === :fill
if length(ex.args) == 3
return :($f($(esc(ex.args[2])), $SV{$(esc(ex.args[3]))}))
elseif f === :fill
# for calls like `fill(value, dim)`
if length(fargs) == 2
return :($f($(esc(fargs[1])), $SV{$(esc(fargs[2]))}))
else
error("@$SV expected a 1-dimensional array expression")
end
elseif f === :rand || f === :randn || f === :randexp
_f_with_Val = Symbol(:_, f, :_with_Val)
if length(fargs) == 1
# for calls like `rand(dim)`
# for calls like `randn(dim)`
# for calls like `randexp(dim)`
return :($f($SV{$(escall(fargs)...)}))
elseif length(fargs) == 2
# for calls like `rand(rng, dim)`
# for calls like `rand(type, dim)`
# for calls like `rand(sampler, dim)`
# for calls like `randn(rng, dim)`
# for calls like `randn(type, dim)`
# for calls like `randexp(rng, dim)`
# for calls like `randexp(type, dim)`
return :($_f_with_Val($SV, $(esc(fargs[1])), Val($(esc(fargs[2])))))
elseif length(fargs) == 3
# for calls like `rand(rng, type, dim)`
# for calls like `rand(rng, sampler, dim)`
# for calls like `randn(rng, type, dim)`
# for calls like `randexp(rng, type, dim)`
return :($_f_with_Val($SV, $(esc(fargs[1])), $(esc(fargs[2])), Val($(esc(fargs[3])))))
else
error("@$SV got bad expression: $(ex)")
end
else
error("@$SV only supports the zeros(), ones(), rand(), randn() and randexp() functions.")
error("@$SV only supports the zeros(), ones(), fill(), rand(), randn(), and randexp() functions.")
end
else
error("Use @$SV [a,b,c], @$SV Type[a,b,c] or a comprehension like @$SV [f(i) for i = i_min:i_max]")
Expand Down
4 changes: 2 additions & 2 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ end

@inline rand(rng::AbstractRNG, range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(rng, range, Size(SA), SA)
@inline rand(range::AbstractArray, ::Type{SA}) where {SA <: StaticArray} = _rand(Random.GLOBAL_RNG, range, Size(SA), SA)
@generated function _rand(rng::AbstractRNG, range::AbstractArray, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, range)) for i = 1:prod(s)]
@generated function _rand(rng::AbstractRNG, X, ::Size{s}, ::Type{SA}) where {s, SA <: StaticArray}
v = [:(rand(rng, X)) for i = 1:prod(s)]
return quote
@_inline_meta
$SA(tuple($(v...)))
Expand Down
Loading

0 comments on commit bc00c6a

Please sign in to comment.