Skip to content

Commit

Permalink
Fix promotion precision loss (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 2, 2024
1 parent 2220369 commit 4e5fa52
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 59 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Changelog

- Fixed various deprecation warnings.

**Bug fixes**

- Fixed scalar promotion logic to more accurately reflect the Array API standard. Promotion now requires at least one array to be present, and scalars adopt the dtype of the arrays they are promoted with. `ndx.utf8` and `ndx.nutf8` cannot be promoted with any other dtype.


0.6.1 (2024-07-12)
------------------
Expand Down
4 changes: 2 additions & 2 deletions ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def log1p(self, x):
return self.add(self.log(x), ndx.asarray(1, x.dtype))

def log2(self, x):
return self.log(x) / np.log(2)
return ndx.log(x) / np.log(2)

def log10(self, x):
return self.log(x) / np.log(10)
return ndx.log(x) / np.log(10)

def logaddexp(self, x, y):
return self.log(self.exp(x) + self.exp(y))
Expand Down
9 changes: 4 additions & 5 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,10 @@ def arange(
device=None,
):
step = asarray(step)
start = asarray(start)

if stop is None:
stop = start
start = asarray(0)
elif start is None:
start = asarray(0)

start, stop, step = promote(start, stop, step)

Expand Down Expand Up @@ -456,11 +455,11 @@ def log1p(x):


def log2(x):
return log(x) / np.log(2)
return _unary("log2", x)


def log10(x):
return log(x) / np.log(10)
return _unary("log10", x)


def logaddexp(x, y):
Expand Down
71 changes: 23 additions & 48 deletions ndonnx/_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,75 +14,50 @@
# FIXME: Remove private import from Spox! Better to just use reshape!
from spox._internal_op import unsafe_reshape

import ndonnx._data_types as dtypes
import ndonnx as ndx

if TYPE_CHECKING:
from ndonnx import Array

from ._corearray import _CoreArray


def _promote_with_none(*args: Array | npt.ArrayLike) -> list[Array | None]:
def promote(*args: Array | npt.ArrayLike | None) -> list[Array]:
"""Promote arguments following numpy's promotion rules.
Constant scalars are converted to `Array` objects.
`None` values are not supported and raise a `TypeError`.
"""
# FIXME: The import structure is rather FUBAR!
from ._array import Array
from ._funcs import asarray, astype, result_type
arrays: list[Array] = []
all_arguments: list[Array] = []

arr_or_none: list[Array | None] = []
scalars: list[Array] = []
signed_integer = False
for arg in args:
if arg is None:
arr_or_none.append(arg)
elif isinstance(arg, Array):
arr_or_none.append(arg)
if arg.dtype in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64):
signed_integer = True
if isinstance(arg, ndx.Array):
arrays.append(arg)
all_arguments.append(arg)
elif isinstance(arg, np.ndarray):
arr_or_none.append(asarray(arg))
elif isinstance(arg, (int, float, str, np.generic)):
np_dtype = np.min_scalar_type(arg)
if np_dtype == np.dtype("float16"):
np_dtype = np.dtype("float32")
arr = asarray(arg, dtypes.from_numpy_dtype(np_dtype))
arr_or_none.append(arr)
scalars.append(arr)
arrays.append(ndx.asarray(arg))
all_arguments.append(arrays[-1])
elif isinstance(arg, (float, int, str, np.generic)):
all_arguments.append(ndx.asarray(arg))
else:
raise TypeError(f"Cannot promote {type(arg)}")

for scalar in scalars:
eager_value = scalar.to_numpy()
if eager_value is None:
raise ValueError("Cannot promote `None` value")
if signed_integer and eager_value.dtype.kind == "u":
# translate dtype to signed version
dtype = np.dtype(f"int{eager_value.dtype.itemsize * 8}")
if eager_value > np.iinfo(dtype).max:
dtype = np.dtype(f"int{eager_value.dtype.itemsize * 16}")
scalar._set(scalar.astype(dtypes.from_numpy_dtype(dtype)))
if not arrays:
raise ValueError("At least one array must be provided for type promotion")

target_dtype = result_type(*[arr for arr in arr_or_none if arr is not None])
target_dtype = ndx.result_type(*arrays)
string_dtypes = (ndx.utf8, ndx.nutf8)
out: list[Array] = []
for arr in all_arguments:
if arr.dtype in string_dtypes and target_dtype not in string_dtypes:
raise TypeError("Cannot promote string scalar to numerical type")
elif arr.dtype not in string_dtypes and target_dtype in string_dtypes:
raise TypeError("Cannot promote non string scalar to string type")
out.append(arr.astype(target_dtype))

return [None if arr is None else astype(arr, target_dtype) for arr in arr_or_none]


def promote(*args: Array | npt.ArrayLike) -> list[Array]:
"""Promote arguments following numpy's promotion rules.
Constant scalars are converted to `Array` objects.
"""
promoted = _promote_with_none(*args)
ret = []
for el in promoted:
if el is None:
raise ValueError("Cannot promote `None` value")
ret.append(el)
return ret
return out


# We assume that rank will be static, because
Expand Down
52 changes: 48 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ndonnx as ndx
import ndonnx.additional as nda
from ndonnx import _data_types as dtypes
from ndonnx._utility import promote

from .utils import get_numpy_array_api_namespace, run

Expand Down Expand Up @@ -411,14 +412,14 @@ def test_array_spox_interoperability():


def test_creation_arange():
a = ndx.arange(10)
np.testing.assert_equal(np.arange(10), a.to_numpy())
a = ndx.arange(0, stop=10)
np.testing.assert_equal(a.to_numpy(), np.arange(stop=10))

b = ndx.arange(1, 10)
np.testing.assert_equal(np.arange(1, 10), b.to_numpy())
np.testing.assert_equal(b.to_numpy(), np.arange(1, 10))

c = ndx.arange(1, 10, 2)
np.testing.assert_equal(np.arange(1, 10, 2), c.to_numpy())
np.testing.assert_equal(c.to_numpy(), np.arange(1, 10, 2))

d = ndx.arange(0.0, None, step=-1)
np.testing.assert_array_equal(
Expand Down Expand Up @@ -634,3 +635,46 @@ def test_array_creation_with_invalid_fields():
def test_promote_nullable():
with pytest.warns(DeprecationWarning):
assert ndx.promote_nullable(np.int64) == ndx.nint64


# If the precision loss looks concerning, see the Array API standard's rules for mixing arrays with Python scalars: https://data-apis.org/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
@pytest.mark.parametrize(
    "arrays, scalar",
    [
        ([ndx.array(shape=("N",), dtype=ndx.uint8)], 1),
        ([ndx.array(shape=("N",), dtype=ndx.uint8)], -1),
        ([ndx.array(shape=("N",), dtype=ndx.int8)], 1),
        ([ndx.array(shape=("N",), dtype=ndx.nint8)], 1),
        # Negative counterpart of the case above (mirrors the uint8 pair);
        # this entry was previously an exact duplicate of the `nint8, 1` case.
        ([ndx.array(shape=("N",), dtype=ndx.nint8)], -1),
        ([ndx.array(shape=("N",), dtype=ndx.float64)], 0.123456789),
        ([ndx.array(shape=("N",), dtype=ndx.float64)], np.float64(0.123456789)),
        ([ndx.array(shape=("N",), dtype=ndx.float32)], 0.123456789),
        (
            [
                ndx.array(shape=("N",), dtype=ndx.float32),
                ndx.asarray([1.5], dtype=ndx.float64),
            ],
            0.123456789,
        ),
        ([ndx.asarray(["a", "b"], dtype=ndx.utf8)], "hello"),
    ],
)
def test_scalar_promote(arrays, scalar):
    """Check that a trailing scalar is promoted to the dtype of the arrays.

    ``promote`` is called with every entry of ``arrays`` followed by ``scalar``.
    Every returned value must be an ``ndx.Array``, the promoted scalar must be
    zero-dimensional, and all promoted arrays must share the scalar's dtype.
    """
    args = arrays + [scalar]
    # The scalar is passed last, so it is also the last element returned.
    *updated_arrays, updated_scalar = promote(*args)
    assert all(isinstance(array, ndx.Array) for array in updated_arrays)
    assert isinstance(updated_scalar, ndx.Array)
    assert updated_scalar.shape == ()
    assert all(array.dtype == updated_scalar.dtype for array in updated_arrays)


@pytest.mark.parametrize(
    "arrays, scalar",
    [
        ([ndx.asarray(["a", "b"], dtype=ndx.utf8)], 1),
        ([ndx.asarray([1, 2], dtype=ndx.int32)], "hello"),
    ],
)
def test_promotion_failures(arrays, scalar):
    """Check that mixing string and numeric operands is rejected.

    Each case pairs an array with a scalar of the opposite kind (string array
    with a number, numeric array with a string); ``promote`` must raise a
    ``TypeError`` whose message matches "Cannot promote".
    """
    operands = [*arrays, scalar]
    expected_failure = pytest.raises(TypeError, match="Cannot promote")
    with expected_failure:
        promote(*operands)

0 comments on commit 4e5fa52

Please sign in to comment.