From 513cd6faf694c82fe6ee8eb2172b57ffbd073c8e Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Fri, 24 Jan 2025 19:52:32 +0000 Subject: [PATCH] Fix defaults --- ndonnx/_data_types/aliases.py | 69 ++++++++++++++++++----------------- ndonnx/_info.py | 29 ++++++++++++--- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/ndonnx/_data_types/aliases.py b/ndonnx/_data_types/aliases.py index 5f1bc0f..a9d625d 100644 --- a/ndonnx/_data_types/aliases.py +++ b/ndonnx/_data_types/aliases.py @@ -63,47 +63,50 @@ nutf8: NUtf8 = NUtf8() +_canonical_names = { + bool: "bool", + float32: "float32", + float64: "float64", + int8: "int8", + int16: "int16", + int32: "int32", + int64: "int64", + uint8: "uint8", + uint16: "uint16", + uint32: "uint32", + uint64: "uint64", + utf8: "utf8", +} + + def canonical_name(dtype: CoreType) -> str: """Return the canonical name of the data type.""" - if dtype == bool: - return "bool" - elif dtype == float32: - return "float32" - elif dtype == float64: - return "float64" - elif dtype == int8: - return "int8" - elif dtype == int16: - return "int16" - elif dtype == int32: - return "int32" - elif dtype == int64: - return "int64" - elif dtype == uint8: - return "uint8" - elif dtype == uint16: - return "uint16" - elif dtype == uint32: - return "uint32" - elif dtype == uint64: - return "uint64" - elif dtype == utf8: - return "utf8" + if dtype in _canonical_names: + return _canonical_names[dtype] else: raise ValueError(f"Unknown data type: {dtype}") +_kinds = { + bool: ("bool",), + int8: ("signed integer", "integer", "numeric"), + int16: ("signed integer", "integer", "numeric"), + int32: ("signed integer", "integer", "numeric"), + int64: ("signed integer", "integer", "numeric"), + uint8: ("unsigned integer", "integer", "numeric"), + uint16: ("unsigned integer", "integer", "numeric"), + uint32: ("unsigned integer", "integer", "numeric"), + uint64: ("unsigned integer", "integer", "numeric"), + float32: ("floating", "numeric"), + float64: ("floating", "numeric"), +} + + def kinds(dtype: CoreType) -> tuple[str, ...]: """Return the kinds of the data type.""" - if dtype in (bool,): - return ("bool",) - if dtype in (int8, int16, int32, int64): - return ("signed integer", "integer", "numeric") - if dtype in (uint8, uint16, uint32, uint64): - return ("unsigned integer", "integer", "numeric") - if dtype in (float32, float64): - return ("floating", "numeric") - if dtype in (utf8,): + if dtype in _kinds: + return _kinds[dtype] + elif dtype in (utf8,): raise ValueError(f"We don't get define a kind for {dtype}") else: raise ValueError(f"Unknown data type: {dtype}") diff --git a/ndonnx/_info.py b/ndonnx/_info.py index d90f92b..cf45b65 100644 --- a/ndonnx/_info.py +++ b/ndonnx/_info.py @@ -3,6 +3,8 @@ from __future__ import annotations +from collections.abc import Iterable + import ndonnx as ndx from ndonnx._array import ndonnx_device from ndonnx._data_types import canonical_name @@ -24,6 +26,11 @@ class ArrayNamespaceInfo: ndx.uint32, ndx.uint64, ] + _defaults = { + "real floating": ndx.float64, + "integral": ndx.int64, + "indexing": ndx.int64, + } def capabilities(self) -> dict: return { @@ -40,7 +47,10 @@ def devices(self) -> list: def dtypes( self, *, device=None, kind: str | tuple[str, ...] | None = None ) -> dict[str, ndx.CoreType]: - # We don't care for device and don't use it. + # We don't care for device since we are writing ONNX graphs. + # We would rather not give users the impression that their arrays + # are tied to a specific device when serializing an ONNX graph as + # such a concept does not exist in the ONNX . out: dict[str, ndx.CoreType] = {} for dtype in self._all_array_api_types: if kind is None or ndx.isdtype(dtype, kind): @@ -50,11 +60,18 @@ def dtypes( def default_dtypes( self, *, device=None, kind: str | tuple[str, ...] | None ) -> dict[str, ndx.CoreType]: - return { - "real floating": ndx.float64, - "integral": ndx.int64, - "indexing": ndx.int64, - } + # See comment in `dtypes` method regarding device. + + kinds: Iterable[str] + + if kind is None: + kinds = self._defaults.keys() + elif isinstance(kind, str): + kinds = [kind] + else: + kinds = kind + + return {kind: self._defaults[kind] for kind in kinds} def __array_namespace_info__() -> ArrayNamespaceInfo: # noqa: N807