From 474f768ba7719eb0eff4658018929521b13a4ca9 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 29 Jan 2025 15:54:46 +0100 Subject: [PATCH] Address feedback --- ndonnx/_array.py | 18 +++++++++++------- ndonnx/_info.py | 16 ++++------------ 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/ndonnx/_array.py b/ndonnx/_array.py index 86086d9..9e8960e 100644 --- a/ndonnx/_array.py +++ b/ndonnx/_array.py @@ -256,12 +256,12 @@ def shape(self) -> tuple[int | None, ...]: @property def device(self): - return ndonnx_device + return device def to_device( - self, device: NdonnxDevice, /, *, stream: int | Any | None = None + self, device: _Device, /, *, stream: int | Any | None = None ) -> Array: - if device is not ndonnx_device: + if device is not device: raise ValueError("Cannot move Array to a different device") if stream is not None: raise ValueError("The 'stream' parameter is not supported in ndonnx.") @@ -592,19 +592,23 @@ def any(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array: return ndx.any(self, axis=axis, keepdims=False) -class NdonnxDevice: +class _Device: + # We would rather not give users the impression that their arrays + # are tied to a specific device when serializing an ONNX graph as + # such a concept does not exist in the ONNX standard. + def __str__(self): return "ndonnx device" def __eq__(self, other): - return type(other) is NdonnxDevice + return type(other) is _Device -ndonnx_device = NdonnxDevice() +device = _Device() __all__ = [ "Array", "array", - "ndonnx_device", + "device", ] diff --git a/ndonnx/_info.py b/ndonnx/_info.py index b107878..edbe89e 100644 --- a/ndonnx/_info.py +++ b/ndonnx/_info.py @@ -4,7 +4,7 @@ from __future__ import annotations import ndonnx as ndx -from ndonnx._array import ndonnx_device +from ndonnx._array import device from ndonnx._data_types import canonical_name @@ -32,18 +32,14 @@ def capabilities(self) -> dict: } def default_device(self): - return ndonnx_device + return device def devices(self) -> list: - return [ndonnx_device] + return [device] def dtypes( self, *, device=None, kind: str | tuple[str, ...] | None = None ) -> dict[str, ndx.CoreType]: - # We don't care for device since we are writing ONNX graphs. - # We would rather not give users the impression that their arrays - # are tied to a specific device when serializing an ONNX graph as - # such a concept does not exist in the ONNX standard. out: dict[str, ndx.CoreType] = {} for dtype in self._all_array_api_types: if kind is None or ndx.isdtype(dtype, kind): @@ -54,15 +50,11 @@ def default_dtypes( self, *, device=None, - ) -> dict[str, ndx.CoreType | None]: - # See comment in `dtypes` method regarding device. + ) -> dict[str, ndx.CoreType]: return { "real floating": ndx.float64, "integral": ndx.int64, "indexing": ndx.int64, - # We don't support complex numbers yet due to immaturity in the ONNX ecoystem, so "complex floating" is meaningless. - # The Array API standard requires this key to be present so we set it to None. - "complex floating": None, }