Skip to content

Commit

Permalink
fix eye
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 20, 2024
1 parent d0dd5f4 commit 38d2bed
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
15 changes: 8 additions & 7 deletions xarray/namedarray/_array_api/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

from xarray.namedarray._array_api._utils import (
_get_data_namespace,
_maybe_default_namespace,
# _maybe_default_namespace,
_get_namespace_dtype,
)
from xarray.namedarray._typing import (
Default,
Expand Down Expand Up @@ -53,7 +54,7 @@ def arange(
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
xp = _get_namespace_dtype(dtype)
_data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device)
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)
Expand Down Expand Up @@ -159,7 +160,7 @@ def asarray(
def empty(
shape: _ShapeType, *, dtype: _DType | None = None, device: _Device | None = None
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
xp = _get_namespace_dtype(dtype)
_data = xp.empty(shape, dtype=dtype, device=device)
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)
Expand Down Expand Up @@ -188,8 +189,8 @@ def eye(
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
_data = xp.eye(n_rows, M=n_cols, k=k, dtype=dtype)
xp = _get_namespace_dtype(dtype)
_data = xp.eye(n_rows, n_cols, k=k, dtype=dtype, device=device)
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)

Expand All @@ -201,7 +202,7 @@ def full(
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
xp = _get_namespace_dtype(dtype)
_data = xp.full(shape, fill_value, dtype=dtype, device=device)
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)
Expand Down Expand Up @@ -232,7 +233,7 @@ def linspace(
device: _Device | None = None,
endpoint: bool = True,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
xp = _get_namespace_dtype(dtype)
_data = xp.linspace(
start,
stop,
Expand Down
5 changes: 4 additions & 1 deletion xarray/namedarray/_array_api/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def _get_data_namespace(x: NamedArray[Any, Any]) -> ModuleType:
return _maybe_default_namespace()


def _get_namespace_dtype(dtype: _dtype) -> ModuleType:
def _get_namespace_dtype(dtype: _dtype | None = None) -> ModuleType:
if dtype is None:
return _maybe_default_namespace()

xp = __import__(dtype.__module__)
return xp

0 comments on commit 38d2bed

Please sign in to comment.