From 64b9f3734a78ead6775c22c8f6c22b06e1bab181 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Mon, 22 Jul 2024 13:07:33 +0100 Subject: [PATCH] Add dicts to eager propagation decorator (#21) --- ndonnx/_propagation.py | 2 ++ tests/ndonnx/test_constant_propagation.py | 26 +++++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/ndonnx/_propagation.py b/ndonnx/_propagation.py index b1fdce1..9efeb12 100644 --- a/ndonnx/_propagation.py +++ b/ndonnx/_propagation.py @@ -133,6 +133,8 @@ def collect_lazy_arguments(obj): return ndx.asarray(obj_value, obj.dtype) elif isinstance(obj, (list, tuple)): return type(obj)(map(collect_lazy_arguments, obj)) + elif isinstance(obj, dict): + return {key: collect_lazy_arguments(value) for key, value in obj.items()} elif isinstance(obj, slice): return slice( collect_lazy_arguments(obj.start), diff --git a/tests/ndonnx/test_constant_propagation.py b/tests/ndonnx/test_constant_propagation.py index 8c2ad69..0a92591 100644 --- a/tests/ndonnx/test_constant_propagation.py +++ b/tests/ndonnx/test_constant_propagation.py @@ -5,8 +5,10 @@ import numpy as np import pytest +import spox.opset.ai.onnx.v20 as op import ndonnx as ndx +from ndonnx._propagation import eager_propagate def test_add(): @@ -237,3 +239,27 @@ def test_where_folding(cond, x, y, expected_operators): model_proto = ndx.build(inputs, {"out": out}) operators_used_const = {node.op_type for node in model_proto.graph.node} assert operators_used_const == expected_operators + + +def test_eager_propagation_nested_parameters(): + @eager_propagate + def function( + x: ndx.Array, mapping: dict[str, ndx.Array], seq: list[ndx.Array] + ) -> tuple[ndx.Array, ndx.Array]: + # do some spox stuff + a = ndx.from_spox_var(op.sigmoid(mapping["a"].astype(ndx.float64).spox_var())) + b = ndx.from_spox_var( + op.regex_full_match(seq[0].spox_var(), pattern="^hello.*") + ) + return (a + x) * mapping["b"], b + + x, y = function( + ndx.asarray([1, 2, 3, 4]), + {"a": ndx.asarray([1, -10, 120, 40]), "b": 10}, + [ndx.asarray(["a", "hello world", "world hello"])], + ) + expected_x = np.asarray([17.310586, 20.000454, 40.0, 50.0]) + expected_y = np.asarray([False, True, False]) + + np.testing.assert_allclose(x.to_numpy(), expected_x) + np.testing.assert_array_equal(y.to_numpy(), expected_y)