diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ec24c23..a4490b8 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ Changelog **Bug fix** - Fixed a bug in the construction of nullable arrays using :func:`ndonnx.asarray` where the shape of the null field would not match the values field if the provided `np.ma.MaskedArray`'s mask was scalar. +- Fixed a bug in the implementation of :func:`ndonnx.ones_like` for arrays without a concrete shape. 0.9.0 (2024-08-30) diff --git a/ndonnx/_core/_shapeimpl.py b/ndonnx/_core/_shapeimpl.py index 4e90232..56a3c8d 100644 --- a/ndonnx/_core/_shapeimpl.py +++ b/ndonnx/_core/_shapeimpl.py @@ -252,7 +252,7 @@ def zeros_like(self, x, dtype=None, device=None): return ndx.zeros(nda.shape(x), dtype=dtype or x.dtype, device=device) def ones_like(self, x, dtype=None, device=None): - return ndx.ones(x.shape, dtype=dtype or x.dtype, device=device) + return ndx.ones(nda.shape(x), dtype=dtype or x.dtype, device=device) def make_array( self, diff --git a/tests/test_core.py b/tests/test_core.py index b0eaae3..5063954 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -500,6 +500,16 @@ def test_creation_full(): ) +def test_creation_ones_like(): + a = ndx.array(shape=("N",), dtype=ndx.int64) + b = ndx.ones_like(a) + model = ndx.build({"a": a}, {"b": b}) + assert_array_equal( + run(model, {"a": np.array([1, 2, 3], dtype=np.int64)})["b"], + np.ones(3, dtype=np.int64), + ) + + @pytest.mark.parametrize( "args, expected", [