Skip to content

Commit

Permalink
Allow lazy fill with scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jul 25, 2024
1 parent 4583e13 commit 0cd3bfc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
10 changes: 6 additions & 4 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ def full(
raise ValueError("fill_value must be a scalar")

dtype = fill_value.dtype if dtype is None else dtype
shape = asarray(shape, dtype=dtypes.int64)._core()
return fill_value._transmute(lambda corearray: opx.expand(corearray, shape)).astype(
dtype
)
shape = asarray(shape, dtype=dtypes.int64)
if shape.ndim == 0:
shape = reshape(shape, [1])
return fill_value._transmute(
lambda corearray: opx.expand(corearray, shape._core())
).astype(dtype)


def full_like(
Expand Down
19 changes: 19 additions & 0 deletions tests/ndonnx/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,25 @@ def test_creation_full():
c = ndx.full((2, 3), "a", dtype=ndx.nutf8)
np.testing.assert_equal(np.full((2, 3), "a"), c.to_numpy())

d = ndx.full(2, 5, dtype=ndx.int8)
np.testing.assert_equal(np.full(2, 5, dtype=np.int8), d.to_numpy())

# Check lazy creation
e = ndx.array(shape=tuple(), dtype=ndx.int64)
f = ndx.full(e, 10)
model_proto = ndx.build({"e": e}, {"f": f})
actual = run(model_proto, {"e": np.array(5, dtype=np.int64)})["f"]
np.testing.assert_equal(np.array([10] * 5, dtype=np.int64), actual)

# Note we must know the output shape to export an ONNX artifact.
g = ndx.array(shape=(2,), dtype=ndx.int64)
h = ndx.full(g, 10)
model_proto = ndx.build({"g": g}, {"h": h})
np.testing.assert_equal(
np.array([[10, 10, 10], [10, 10, 10]]),
run(model_proto, {"g": np.array([2, 3])})["h"],
)


@pytest.mark.parametrize(
"args, expected",
Expand Down

0 comments on commit 0cd3bfc

Please sign in to comment.