diff --git a/xarray/namedarray/_array_api/creation_functions.py b/xarray/namedarray/_array_api/creation_functions.py index 957f11e6940..f0588fba8d6 100644 --- a/xarray/namedarray/_array_api/creation_functions.py +++ b/xarray/namedarray/_array_api/creation_functions.py @@ -4,12 +4,16 @@ import numpy as np -from xarray.namedarray._array_api._utils import _maybe_default_namespace +from xarray.namedarray._array_api._utils import ( + _maybe_default_namespace, + _get_data_namespace, +) from xarray.namedarray._typing import ( Default, _arrayfunction_or_api, _ArrayLike, _default, + _Device, _DimsLike, _DType, _Shape, @@ -24,6 +28,12 @@ ) +def _like_args(x, dtype=None, device: _Device | None = None): + if dtype is None: + dtype = x.dtype + return dict(shape=x.shape, dtype=dtype, device=device) + + def _infer_dims( shape: _Shape, dims: _DimsLike | Default = _default, @@ -41,7 +51,7 @@ def arange( step: int | float = 1, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() _data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device) @@ -55,7 +65,7 @@ def asarray( /, *, dtype: _DType, - device=..., + device: _Device | None = ..., copy: bool | None = ..., dims: _DimsLike = ..., ) -> NamedArray[_ShapeType, _DType]: ... @@ -65,7 +75,7 @@ def asarray( /, *, dtype: _DType, - device=..., + device: _Device | None = ..., copy: bool | None = ..., dims: _DimsLike = ..., ) -> NamedArray[Any, _DType]: ... @@ -75,7 +85,7 @@ def asarray( /, *, dtype: None, - device=None, + device: _Device | None = None, copy: bool | None = None, dims: _DimsLike = ..., ) -> NamedArray[_ShapeType, _DType]: ... @@ -85,7 +95,7 @@ def asarray( /, *, dtype: None, - device=..., + device: _Device | None = ..., copy: bool | None = ..., dims: _DimsLike = ..., ) -> NamedArray[Any, _DType]: ... @@ -94,7 +104,7 @@ def asarray( /, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, copy: bool | None = None, dims: _DimsLike = _default, ) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]: @@ -146,12 +156,35 @@ def asarray( return NamedArray(_dims, _data) +def empty( + shape: _ShapeType, *, dtype: _DType | None = None, device: _Device | None = None +) -> NamedArray[_ShapeType, _DType]: + xp = _maybe_default_namespace() + _data = xp.empty(shape, dtype=dtype, device=device) + _dims = _infer_dims(_data.shape) + return NamedArray(_dims, _data) + + +def empty_like( + x: NamedArray[_ShapeType, _DType], + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.empty(x.shape, dtype=_dtype, device=_device) + return x._new(data=_data) + + def full( shape: _Shape, fill_value: bool | int | float | complex, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() _data = xp.full(shape, fill_value, dtype=dtype, device=device) @@ -159,6 +192,21 @@ def full( return NamedArray(_dims, _data) +def full_like( + x: NamedArray[_ShapeType, _DType], + fill_value: bool | int | float | complex, + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.full(x.shape, fill_value, dtype=_dtype, device=_device) + return x._new(data=_data) + + def linspace( start: int | float | complex, stop: int | float | complex, @@ -166,7 +214,7 @@ def linspace( num: int, *, dtype: _DType | None = None, - device=None, + device: _Device | None = None, endpoint: bool = True, ) -> NamedArray[_ShapeType, _DType]: xp = _maybe_default_namespace() @@ -178,12 +226,40 @@ def linspace( def ones( - shape: _Shape, *, dtype: _DType | None = None, device=None + shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: return full(shape, 1, dtype=dtype, device=device) +def ones_like( + x: NamedArray[_ShapeType, _DType], + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.ones(x.shape, dtype=_dtype, device=_device) + return x._new(data=_data) + + def zeros( - shape: _Shape, *, dtype: _DType | None = None, device=None + shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None ) -> NamedArray[_ShapeType, _DType]: return full(shape, 0, dtype=dtype, device=device) + + +def zeros_like( + x: NamedArray[_ShapeType, _DType], + /, + *, + dtype: _DType | None = None, + device: _Device | None = None, +) -> NamedArray[_ShapeType, _DType]: + xp = _get_data_namespace(x) + _dtype = x.dtype if dtype is None else dtype + _device = x.device if device is None else device + _data = xp.zeros(x.shape, dtype=_dtype, device=_device) + return x._new(data=_data)