Skip to content

Commit

Permalink
Add dicts to eager propagation decorator (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Jul 22, 2024
1 parent 7650f9c commit 64b9f37
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ndonnx/_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ def collect_lazy_arguments(obj):
return ndx.asarray(obj_value, obj.dtype)
elif isinstance(obj, (list, tuple)):
return type(obj)(map(collect_lazy_arguments, obj))
elif isinstance(obj, dict):
return {key: collect_lazy_arguments(value) for key, value in obj.items()}
elif isinstance(obj, slice):
return slice(
collect_lazy_arguments(obj.start),
Expand Down
26 changes: 26 additions & 0 deletions tests/ndonnx/test_constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import numpy as np
import pytest
import spox.opset.ai.onnx.v20 as op

import ndonnx as ndx
from ndonnx._propagation import eager_propagate


def test_add():
Expand Down Expand Up @@ -237,3 +239,27 @@ def test_where_folding(cond, x, y, expected_operators):
model_proto = ndx.build(inputs, {"out": out})
operators_used_const = {node.op_type for node in model_proto.graph.node}
assert operators_used_const == expected_operators


def test_eager_propagation_nested_parameters():
@eager_propagate
def function(
x: ndx.Array, mapping: dict[str, ndx.Array], seq: list[ndx.Array]
) -> tuple[ndx.Array, ndx.Array]:
# do some spox stuff
a = ndx.from_spox_var(op.sigmoid(mapping["a"].astype(ndx.float64).spox_var()))
b = ndx.from_spox_var(
op.regex_full_match(seq[0].spox_var(), pattern="^hello.*")
)
return (a + x) * mapping["b"], b

x, y = function(
ndx.asarray([1, 2, 3, 4]),
{"a": ndx.asarray([1, -10, 120, 40]), "b": 10},
[ndx.asarray(["a", "hello world", "world hello"])],
)
expected_x = np.asarray([17.310586, 20.000454, 40.0, 50.0])
expected_y = np.asarray([False, True, False])

np.testing.assert_allclose(x.to_numpy(), expected_x)
np.testing.assert_array_equal(y.to_numpy(), expected_y)

0 comments on commit 64b9f37

Please sign in to comment.