Skip to content

Commit

Permalink
Merge pull request #52 from dfdx/wct/fix-splatting
Browse files Browse the repository at this point in the history
Fix splatting properly
  • Loading branch information
dfdx authored Nov 20, 2023
2 parents a82adaf + 7868498 commit 0263cdf
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Umlaut"
uuid = "92992a2b-8ce5-4a9c-bb9d-58be9a7dc841"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.6.0"
version = "0.6.1"

[deps]
CompilerPluginTools = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3638"
Expand Down
10 changes: 8 additions & 2 deletions src/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ function code_signature(ctx, v_fargs)
return fargtypes
end

__to_tuple__(x::Tuple) = x
__to_tuple__(x::NamedTuple) = Tuple(x)
__to_tuple__(x::Array) = Tuple(x)
__to_tuple__(x) = __to_tuple__(collect(x))

"""
unsplat!(t::Tracer, v_fargs)
Expand All @@ -497,8 +501,10 @@ function unsplat!(t::Tracer, v_fargs)
if is_tuple
push!(actual_v_args, iter.args...)
else
for i in eachindex(iter.val)
x = push!(t.tape, mkcall(getfield, v, i; line="Umlaut.unsplat!"))
tuple_v = push!(t.tape, mkcall(__to_tuple__, v; line="Umlaut.unsplat!"))
iter = t.tape[tuple_v].val
for i in eachindex(iter)
x = push!(t.tape, mkcall(getfield, tuple_v, i; line="Umlaut.unsplat!"))
push!(actual_v_args, x)
end
end
Expand Down
37 changes: 32 additions & 5 deletions test/test_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,22 @@ function vararg_fn(x, xs...)
return x + sum(xs)
end

multiarg_fn(x) = x
multiarg_fn(x, y) = x + y
multiarg_fn(x, y, z) = x + y + z
multiarg_fn(x) = only(x)
multiarg_fn(x, y) = only(x) + only(y)
multiarg_fn(x, y, z) = only(x) + only(y) + only(z)


@testset "trace: varargs, splatting" begin

@testset "__to_tuple__" begin
@test Umlaut.__to_tuple__(5.0) == (5.0, )
@test Umlaut.__to_tuple__((5.0, 4.0)) == (5.0, 4.0)
@test Umlaut.__to_tuple__((a=5.0, b="hi")) == (5.0, "hi")
@test Umlaut.__to_tuple__([5.0, 4.0, 3.0]) == (5.0, 4.0, 3.0)
@test Umlaut.__to_tuple__(zip([1.0, 2.0, 3.0])) == ((1.0, ), (2.0, ), (3.0, ))
@test Umlaut.__to_tuple__(Core.svec(1.0, 2.0)) == (1.0, 2.0)
end

# varargs
_, tape = trace(vararg_fn, 1, 2, 3)
@test play!(tape, vararg_fn, 4, 5, 6) == vararg_fn(4, 5, 6)
Expand All @@ -301,8 +311,8 @@ multiarg_fn(x, y, z) = x + y + z
f = t -> multiarg_fn(t...)
_, tape = trace(f, (1, 2))
@test play!(tape, f, (3, 4)) == f((3, 4))
@test tape[V(4)].fn == Base.getfield
@test tape[V(5)].fn == Base.getfield
@test tape[V(6)].fn == Base.getfield

@test_logs (:warn, "Variable %2 had length 2 during tracing, but now has length 3") play!(tape, f, (5, 6, 7))

Expand All @@ -311,8 +321,25 @@ multiarg_fn(x, y, z) = x + y + z
_, tape = trace(f, 1)
@test play!(tape, f, 2) == f(2)
v2 = V(tape, 2)
@test (tape[V(end)].fn == +) && (tape[V(end)].args == [v2, v2, 1])
v6 = V(tape, 6)
if VERSION >= v"1.9"
@test (tape[V(end)].fn == +) && (tape[V(end)].args == [v2, v2, v6])
end

test_f = x -> multiarg_fn(x...)
@testset "$name" for (name, x) in [
("splat single Int", 1),
("splat single Float64", 1.0),
("splat Vector{Float64}", [1.0, 2.0]),
("splat Tuple{Float64, Int}", (5.0, 4)),
("splat NamedTuple(Float64, Int)", (a=5.0, b=2)),
("splat zip", zip([1.0, 2.0, 3.0])),
("splat Core.SimpleVector", Core.svec(1.0, 2.)),
]
@test test_f(x) === test_f(x)
v, tape = trace(test_f, x)
@test v == test_f(x)
end
end


Expand Down

0 comments on commit 0263cdf

Please sign in to comment.