Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 29, 2025
1 parent 32ea1f8 commit 474f768
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
18 changes: 11 additions & 7 deletions ndonnx/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,12 @@ def shape(self) -> tuple[int | None, ...]:

@property
def device(self):
return ndonnx_device
return device

def to_device(
self, device: NdonnxDevice, /, *, stream: int | Any | None = None
self, device: _Device, /, *, stream: int | Any | None = None
) -> Array:
if device is not ndonnx_device:
if device is not device:
raise ValueError("Cannot move Array to a different device")
if stream is not None:
raise ValueError("The 'stream' parameter is not supported in ndonnx.")
Expand Down Expand Up @@ -592,19 +592,23 @@ def any(self, axis: int | None = 0, keepdims: bool | None = False) -> ndx.Array:
return ndx.any(self, axis=axis, keepdims=False)


class NdonnxDevice:
class _Device:
# We would rather not give users the impression that their arrays
# are tied to a specific device when serializing an ONNX graph as
# such a concept does not exist in the ONNX standard.

def __str__(self):
return "ndonnx device"

def __eq__(self, other):
return type(other) is NdonnxDevice
return type(other) is _Device


ndonnx_device = NdonnxDevice()
device = _Device()


__all__ = [
"Array",
"array",
"ndonnx_device",
"device",
]
16 changes: 4 additions & 12 deletions ndonnx/_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

import ndonnx as ndx
from ndonnx._array import ndonnx_device
from ndonnx._array import device
from ndonnx._data_types import canonical_name


Expand Down Expand Up @@ -32,18 +32,14 @@ def capabilities(self) -> dict:
}

def default_device(self):
return ndonnx_device
return device

def devices(self) -> list:
return [ndonnx_device]
return [device]

def dtypes(
self, *, device=None, kind: str | tuple[str, ...] | None = None
) -> dict[str, ndx.CoreType]:
# We don't care for device since we are writing ONNX graphs.
# We would rather not give users the impression that their arrays
# are tied to a specific device when serializing an ONNX graph as
# such a concept does not exist in the ONNX standard.
out: dict[str, ndx.CoreType] = {}
for dtype in self._all_array_api_types:
if kind is None or ndx.isdtype(dtype, kind):
Expand All @@ -54,15 +50,11 @@ def default_dtypes(
self,
*,
device=None,
) -> dict[str, ndx.CoreType | None]:
# See comment in `dtypes` method regarding device.
) -> dict[str, ndx.CoreType]:
return {
"real floating": ndx.float64,
"integral": ndx.int64,
"indexing": ndx.int64,
# We don't support complex numbers yet due to immaturity in the ONNX ecoystem, so "complex floating" is meaningless.
# The Array API standard requires this key to be present so we set it to None.
"complex floating": None,
}


Expand Down

0 comments on commit 474f768

Please sign in to comment.