Skip to content

Commit

Permalink
Chores (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau authored Jul 26, 2024
1 parent 8a57cb8 commit 6344423
Show file tree
Hide file tree
Showing 16 changed files with 65 additions and 27 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,30 @@
Changelog
=========

0.6.2 (unreleased)
------------------

**Other changes**

- Fixed various deprecation warnings


0.6.1 (2024-07-12)
------------------

**Bug fixes**

- Division now complies more strictly with the Array API standard by returning a floating-point result regardless of input data types.


0.6.0 (2024-07-11)
------------------

**Other changes**

- ``ndonnx.promote_nullable`` is now publicly exported.


0.5.0 (2024-07-01)
------------------

Expand All @@ -31,6 +41,7 @@ Changelog

- ``__array_namespace__`` now accepts the optional ``api_version`` argument to specify the version of the Array API to use.


0.4.0 (2024-05-16)
------------------

Expand Down
13 changes: 6 additions & 7 deletions ndonnx/_core/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import ndonnx._data_types as dtypes
import ndonnx._opset_extensions as opx
from ndonnx._utility import promote
from ndonnx.additional import make_nullable

from ._interface import OperationsBlock

Expand Down Expand Up @@ -311,7 +310,7 @@ def sign(self, x):
if x_null is None:
return out_values
else:
return make_nullable(out_values, x_null)
return ndx.additional.make_nullable(out_values, x_null)

def sin(self, x):
return _unary_op(x, opx.sin, dtypes.float64)
Expand Down Expand Up @@ -549,7 +548,7 @@ def where_dtype_agnostic(a: ndx.Array, b: ndx.Array) -> ndx.Array:
if condition.dtype == ndx.nbool:
# propagate null if present
if not isinstance(output.dtype, dtypes.Nullable):
output = make_nullable(output, condition.null)
output = ndx.additional.make_nullable(output, condition.null)
else:
output.null = output.null | condition.null
return output
Expand Down Expand Up @@ -889,9 +888,9 @@ def make_nullable(self, x, null):
if null.dtype != dtypes.bool:
raise TypeError("null must be a boolean array")
if not isinstance(x.dtype, dtypes.CoreType):
raise TypeError("make_nullable does not accept nullable arrays")
raise TypeError("'make_nullable' does not accept nullable arrays")
return ndx.Array._from_fields(
dtypes.promote_nullable(x.dtype),
dtypes.into_nullable(x.dtype),
values=x.copy(),
null=ndx.reshape(null, x.shape),
)
Expand Down Expand Up @@ -939,7 +938,7 @@ def _variadic_op(
values = _via_dtype(op, via_dtype, data, cast_return=cast_return)

if (out_null := functools.reduce(_or_nulls, nulls)) is not None:
dtype = dtypes.promote_nullable(values.dtype)
dtype = dtypes.into_nullable(values.dtype)
return ndx.Array._from_fields(dtype, values=values, null=out_null)
else:
return values
Expand Down Expand Up @@ -1002,7 +1001,7 @@ def _via_dtype(

if out_null is not None:
out_value = ndx.Array._from_fields(
dtypes.promote_nullable(out_value.dtype), values=out_value, null=out_null
dtypes.into_nullable(out_value.dtype), values=out_value, null=out_null
)
else:
out_value = ndx.Array._from_fields(out_value.dtype, data=out_value._core())
Expand Down
22 changes: 11 additions & 11 deletions ndonnx/_data_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
from warnings import warn
from ndonnx._utility import deprecated

from .aliases import (
bool,
Expand Down Expand Up @@ -51,9 +51,8 @@
from .structtype import StructType


# TODO: to be removed
def promote_nullable(dtype: StructType | CoreType) -> _NullableCore:
"""Promotes a non-nullable type to its nullable counterpart, if present.
def into_nullable(dtype: StructType | CoreType) -> _NullableCore:
"""Return nullable counterpart, if present.
Parameters
----------
Expand All @@ -70,13 +69,6 @@ def promote_nullable(dtype: StructType | CoreType) -> _NullableCore:
ValueError
If the input type is unknown to ``ndonnx``.
"""

warn(
"Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. "
"To create nullable array, use 'ndonnx.additional.make_nullable' instead.",
DeprecationWarning,
)

if dtype == bool:
return nbool
elif dtype == float32:
Expand Down Expand Up @@ -107,6 +99,14 @@ def promote_nullable(dtype: StructType | CoreType) -> _NullableCore:
raise ValueError(f"Cannot promote {dtype} to nullable")


@deprecated(
"Function 'ndonnx.promote_nullable' will be deprecated in ndonnx 0.7. "
"To create nullable array, use 'ndonnx.additional.make_nullable' instead."
)
def promote_nullable(dtype: StructType | CoreType) -> _NullableCore:
return into_nullable(dtype)


__all__ = [
"CoreType",
"StructType",
Expand Down
4 changes: 2 additions & 2 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def asarray(
if dtype is None:
dtype = dtypes.from_numpy_dtype(arr.dtype)
if isinstance(arr, np.ma.masked_array):
dtype = dtypes.promote_nullable(dtype)
dtype = dtypes.into_nullable(dtype)

ret = Array._construct(
shape=arr.shape, dtype=dtype, eager_values=dtype._parse_input(arr)
Expand Down Expand Up @@ -297,7 +297,7 @@ def result_type(

ret_dtype = dtypes.from_numpy_dtype(np.result_type(*np_dtypes))
if nullable:
return dtypes.promote_nullable(ret_dtype)
return dtypes.into_nullable(ret_dtype)
else:
return ret_dtype

Expand Down
16 changes: 16 additions & 0 deletions ndonnx/_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING
from warnings import warn

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -111,3 +113,17 @@ def unwrap_var(tensor: _CoreArray | Var) -> Var:
return tensor
else:
return tensor.var


def deprecated(msg: str):
"""Decorates a function as deprecated and raises a warning when it is called."""

def _deprecated(fn):
@wraps(fn)
def inner(*args, **kwargs):
warn(msg, DeprecationWarning)
return fn(*args, **kwargs)

return inner

return _deprecated
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ module = ["onnxruntime"]
ignore_missing_imports = true

[tool.pytest.ini_options]
pythonpath = "."
addopts = "--import-mode=importlib --ignore=api-coverage-tests"
testpaths = ["tests"]
exclude = ["docs/"]
addopts = "--ignore=api-coverage-tests"
filterwarnings = [
"ignore:.*google.protobuf.pyext.*:DeprecationWarning",
]

[tool.typos.default]
extend-ignore-identifiers-re = ["scatter_nd", "arange"]
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_input_output_name_backwards_compatibility(dtype):
model_proto = ndx.build({"input": a}, {"output": a})
assert [node.name for node in model_proto.graph.input] == ["input"]
assert [node.name for node in model_proto.graph.output] == ["output"]
a = ndx.array(shape=("N",), dtype=ndx._data_types.promote_nullable(dtype))
a = ndx.array(shape=("N",), dtype=ndx._data_types.into_nullable(dtype))
model_proto = ndx.build({"input": a}, {"output": a})
assert [node.name for node in model_proto.graph.input] == [
"input_values",
Expand Down
File renamed without changes.
File renamed without changes.
7 changes: 6 additions & 1 deletion tests/ndonnx/test_core.py → tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
def numpy_to_graph_input(arr, eager=False):
dtype: dtypes.CoreType | dtypes.StructType
if isinstance(arr, np.ma.MaskedArray):
dtype = dtypes.promote_nullable(dtypes.from_numpy_dtype(arr.dtype))
dtype = dtypes.into_nullable(dtypes.from_numpy_dtype(arr.dtype))
else:
dtype = dtypes.from_numpy_dtype(arr.dtype)
return (
Expand Down Expand Up @@ -601,3 +601,8 @@ def test_prod_no_implementation(dtype):
x = ndx.asarray([2, 2]).astype(dtype)
with pytest.raises(TypeError):
ndx.prod(x)


def test_promote_nullable():
with pytest.warns(DeprecationWarning):
assert ndx.promote_nullable(np.int64) == ndx.nint64
File renamed without changes.
File renamed without changes.
7 changes: 6 additions & 1 deletion tests/ndonnx/test_masked.py → tests/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import math
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -116,7 +117,11 @@ def test_unary_none_propagation(fn_name, args, kwargs):
npx = get_numpy_array_api_namespace()
np_fn = getattr(npx, fn_name)
inp_a = npx.asarray(np.ma.filled(inp_a, np.nan))
expected_b = np_fn(inp_a, *args, **kwargs)

# Numpy might complain about invalid values
with warnings.catch_warnings():
warnings.simplefilter("ignore")
expected_b = np_fn(inp_a, *args, **kwargs)
np.testing.assert_almost_equal(
np.ma.masked_array(expected_b, mask=missing_a),
ret_b,
Expand Down
File renamed without changes.

0 comments on commit 6344423

Please sign in to comment.