Skip to content

Commit

Permalink
Fix ones_like
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 1, 2024
1 parent fd81960 commit 4464a38
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changelog
**Bug fix**

- Fixed a bug in the construction of nullable arrays using :func:`ndonnx.asarray` where the shape of the null field would not match the values field if the provided `np.ma.MaskedArray`'s mask was scalar.
- Fixed a bug in the implementation of :func:`ndonnx.ones_like` for arrays without a concrete shape.


0.9.0 (2024-08-30)
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def zeros_like(self, x, dtype=None, device=None):
return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device)

def ones_like(self, x, dtype=None, device=None):
return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device)
return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device)

def make_array(
self,
Expand Down
10 changes: 10 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,16 @@ def test_creation_full():
)


def test_creation_ones_like():
a = ndx.array(shape=("N",), dtype=ndx.int64)
b = ndx.ones_like(a)
model = ndx.build({"a": a}, {"b": b})
assert_array_equal(
run(model, {"a": np.array([1, 2, 3], dtype=np.int64)})["b"],
np.ones(3, dtype=np.int64),
)


@pytest.mark.parametrize(
"args, expected",
[
Expand Down

0 comments on commit 4464a38

Please sign in to comment.