diff --git a/xarray/namedarray/_array_api/_creation_functions.py b/xarray/namedarray/_array_api/_creation_functions.py index 1ea8418b49c..196c5c9b9f0 100644 --- a/xarray/namedarray/_array_api/_creation_functions.py +++ b/xarray/namedarray/_array_api/_creation_functions.py @@ -6,7 +6,8 @@ from xarray.namedarray._array_api._utils import ( _get_data_namespace, - _maybe_default_namespace, + # _maybe_default_namespace, + _get_namespace_dtype, ) from xarray.namedarray._typing import ( Default, @@ -53,7 +54,7 @@ def arange( dtype: _DType | None = None, device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -159,7 +160,7 @@ def asarray( def empty( shape: _ShapeType, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.empty(shape, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -188,8 +189,8 @@ def eye( dtype: _DType | None = None, device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() - _data = xp.eye(n_rows, M=n_cols, k=k, dtype=dtype) + xp = _get_namespace_dtype(dtype) + _data = xp.eye(n_rows, n_cols, k=k, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -201,7 +202,7 @@ def full( dtype: _DType | None = None, device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.full(shape, fill_value, dtype=dtype, device=device) _dims = _infer_dims(_data.shape) return NamedArray(_dims, _data) @@ -232,7 +233,7 @@ def linspace( device: _Device | None = None, endpoint: bool = True, ) -> NamedArray[_ShapeType, _DType]: - xp = _maybe_default_namespace() + xp = _get_namespace_dtype(dtype) _data = xp.linspace( start, stop, diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 8a296ee552e..018b84ca54c 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -28,6 +28,9 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType: return _maybe_default_namespace() -def _get_namespace_dtype(dtype: _dtype) -> ModuleType: +def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType: + if dtype is None: + return _maybe_default_namespace() + xp = __import__(dtype.__module__) return xp