diff --git a/.gitmodules b/.gitmodules index 4562ffa..1f68eb6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,3 @@ -[submodule "array-api-tests"] - path = api-coverage-tests - url = git@github.com:adityagoel4512/array-api-tests.git - branch = skip-unused-kinds-in-default +[submodule "api-coverage-tests"] + path = array-api-tests + url = git@github.com:data-apis/array-api-tests.git diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0468cae..806a5ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -48,7 +48,7 @@ repos: language: system types: [python] require_serial: true - exclude: ^(tests|api-coverage-tests)/ + exclude: ^(tests|array-api-tests)/ # prettier - id: prettier name: prettier diff --git a/README.md b/README.md index cb0d52a..db3e5e6 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ pytest tests -n auto It has a couple of key features: -- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like `NumPy`, `JAX` and now `ndonnx`. +- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like NumPy, JAX and now ndonnx. ```python import numpy as np @@ -93,7 +93,7 @@ In the future we will be enabling a stable API for an extensible data type syste ## Array API coverage -Array API compatibility is tracked in `api-coverage-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome! +Array API compatibility is tracked in `array-api-tests`. Missing coverage is tracked in the `skips.txt` file. Contributions are welcome! Summary(1119 total): diff --git a/api-coverage-tests b/api-coverage-tests deleted file mode 160000 index aec51e9..0000000 --- a/api-coverage-tests +++ /dev/null @@ -1 +0,0 @@ -Subproject commit aec51e9da724fa387fe7e4fd3ee06367681b3c4a diff --git a/array-api-tests b/array-api-tests new file mode 160000 index 0000000..dad7731 --- /dev/null +++ b/array-api-tests @@ -0,0 +1 @@ +Subproject commit dad773187be07ecc75b3584ad4b73576f6b17c1e diff --git a/ndonnx/_array.py b/ndonnx/_array.py index 9e8960e..a18abee 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -261,7 +261,7 @@ def device(self): def to_device( self, device: _Device, /, *, stream: int | Any | None = None ) -> Array: - if device is not device: + if device != self.device: raise ValueError("Cannot move Array to a different device") if stream is not None: raise ValueError("The 'stream' parameter is not supported in ndonnx.") diff --git a/ndonnx/_info.py b/ndonnx/_info.py index edbe89e..4e928a2 100644 --- a/ndonnx/_info.py +++ b/ndonnx/_info.py @@ -4,7 +4,7 @@ from __future__ import annotations import ndonnx as ndx -from ndonnx._array import device +from ndonnx._array import _Device, device from ndonnx._data_types import canonical_name @@ -25,16 +25,16 @@ class ArrayNamespaceInfo: ndx.uint64, ] - def capabilities(self) -> dict: + def capabilities(self) -> dict[str, bool]: return { "boolean indexing": True, "data-dependent shapes": True, } - def default_device(self): + def default_device(self) -> _Device: return device - def devices(self) -> list: + def devices(self) -> list[_Device]: return [device] def dtypes( diff --git a/pixi.toml b/pixi.toml index ea21ee0..a34cdc6 100644 --- a/pixi.toml +++ b/pixi.toml @@ -49,7 +49,7 @@ test = "pytest" test-coverage = "pytest --cov=ndonnx --cov-report=xml --cov-report=term-missing" [feature.test.tasks.arrayapitests] -cmd = "pytest api-coverage-tests/array_api_tests/ -v -rfX --json-report --json-report-file=api-coverage-tests.json -n auto --disable-deadline --disable-extension linalg --skips-file=skips.txt --xfails-file=xfails.txt" +cmd = "pytest array_api_tests/ -v -rfX --json-report --json-report-file=api-coverage-tests.json -n auto --disable-deadline --disable-extension linalg --skips-file=skips.txt --xfails-file=xfails.txt" [feature.test.tasks.arrayapitests.env] ARRAY_API_TESTS_MODULE = "ndonnx" ARRAY_API_TESTS_VERSION = "2023.12" diff --git a/pyproject.toml b/pyproject.toml index 9b47f06..0393482 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,14 +77,14 @@ indent-style = "space" python_version = '3.10' no_implicit_optional = true check_untyped_defs = true -exclude = ["api-coverage-tests", "tests"] +exclude = ["array-api-tests", "tests"] [[tool.mypy.overrides]] module = ["onnxruntime"] ignore_missing_imports = true [tool.pytest.ini_options] -addopts = "--ignore=api-coverage-tests" +addopts = "--ignore=array-api-tests" filterwarnings = ["ignore:.*google.protobuf.pyext.*:DeprecationWarning"] [tool.typos.default]