Skip to content

Commit

Permalink
Add more creation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Illviljan committed Aug 19, 2024
1 parent c32abe5 commit 1855f7f
Showing 1 changed file with 87 additions and 11 deletions.
98 changes: 87 additions & 11 deletions xarray/namedarray/_array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@

import numpy as np

from xarray.namedarray._array_api._utils import _maybe_default_namespace
from xarray.namedarray._array_api._utils import (
_maybe_default_namespace,
_get_data_namespace,
)
from xarray.namedarray._typing import (
Default,
_arrayfunction_or_api,
_ArrayLike,
_default,
_Device,
_DimsLike,
_DType,
_Shape,
Expand All @@ -24,6 +28,12 @@
)


def _like_args(x, dtype=None, device: _Device | None = None):
if dtype is None:
dtype = x.dtype
return dict(shape=x.shape, dtype=dtype, device=device)


def _infer_dims(
shape: _Shape,
dims: _DimsLike | Default = _default,
Expand All @@ -41,7 +51,7 @@ def arange(
step: int | float = 1,
*,
dtype: _DType | None = None,
device=None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
_data = xp.arange(start, stop=stop, step=step, dtype=dtype, device=device)
Expand All @@ -55,7 +65,7 @@ def asarray(
/,
*,
dtype: _DType,
device=...,
device: _Device | None = ...,
copy: bool | None = ...,
dims: _DimsLike = ...,
) -> NamedArray[_ShapeType, _DType]: ...
Expand All @@ -65,7 +75,7 @@ def asarray(
/,
*,
dtype: _DType,
device=...,
device: _Device | None = ...,
copy: bool | None = ...,
dims: _DimsLike = ...,
) -> NamedArray[Any, _DType]: ...
Expand All @@ -75,7 +85,7 @@ def asarray(
/,
*,
dtype: None,
device=None,
device: _Device | None = None,
copy: bool | None = None,
dims: _DimsLike = ...,
) -> NamedArray[_ShapeType, _DType]: ...
Expand All @@ -85,7 +95,7 @@ def asarray(
/,
*,
dtype: None,
device=...,
device: _Device | None = ...,
copy: bool | None = ...,
dims: _DimsLike = ...,
) -> NamedArray[Any, _DType]: ...
Expand All @@ -94,7 +104,7 @@ def asarray(
/,
*,
dtype: _DType | None = None,
device=None,
device: _Device | None = None,
copy: bool | None = None,
dims: _DimsLike = _default,
) -> NamedArray[_ShapeType, _DType] | NamedArray[Any, Any]:
Expand Down Expand Up @@ -146,27 +156,65 @@ def asarray(
return NamedArray(_dims, _data)


def empty(
shape: _ShapeType, *, dtype: _DType | None = None, device: _Device | None = None
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
_data = xp.empty(shape, dtype=dtype, device=device)
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)


def empty_like(
x: NamedArray[_ShapeType, _DType],
/,
*,
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _get_data_namespace(x)
_dtype = x.dtype if dtype is None else dtype
_device = x.device if device is None else device
_data = xp.empty(x.shape, dtype=_dtype, device=_device)
return x._new(data=_data)


def full(
shape: _Shape,
fill_value: bool | int | float | complex,
*,
dtype: _DType | None = None,
device=None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
_data = xp.full(shape, fill_value, dtype=dtype, device=device)
_dims = _infer_dims(_data.shape)
return NamedArray(_dims, _data)


def full_like(
x: NamedArray[_ShapeType, _DType],
fill_value: bool | int | float | complex,
/,
*,
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _get_data_namespace(x)
_dtype = x.dtype if dtype is None else dtype
_device = x.device if device is None else device
_data = xp.full(x.shape, fill_value, dtype=_dtype, device=_device)
return x._new(data=_data)


def linspace(
start: int | float | complex,
stop: int | float | complex,
/,
num: int,
*,
dtype: _DType | None = None,
device=None,
device: _Device | None = None,
endpoint: bool = True,
) -> NamedArray[_ShapeType, _DType]:
xp = _maybe_default_namespace()
Expand All @@ -178,12 +226,40 @@ def linspace(


def ones(
shape: _Shape, *, dtype: _DType | None = None, device=None
shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None
) -> NamedArray[_ShapeType, _DType]:
return full(shape, 1, dtype=dtype, device=device)


def ones_like(
x: NamedArray[_ShapeType, _DType],
/,
*,
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _get_data_namespace(x)
_dtype = x.dtype if dtype is None else dtype
_device = x.device if device is None else device
_data = xp.ones(x.shape, dtype=_dtype, device=_device)
return x._new(data=_data)


def zeros(
shape: _Shape, *, dtype: _DType | None = None, device=None
shape: _Shape, *, dtype: _DType | None = None, device: _Device | None = None
) -> NamedArray[_ShapeType, _DType]:
return full(shape, 0, dtype=dtype, device=device)


def zeros_like(
x: NamedArray[_ShapeType, _DType],
/,
*,
dtype: _DType | None = None,
device: _Device | None = None,
) -> NamedArray[_ShapeType, _DType]:
xp = _get_data_namespace(x)
_dtype = x.dtype if dtype is None else dtype
_device = x.device if device is None else device
_data = xp.zeros(x.shape, dtype=_dtype, device=_device)
return x._new(data=_data)

0 comments on commit 1855f7f

Please sign in to comment.