Skip to content

Commit

Permalink
Play with array-api tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 8, 2024
1 parent 4d896e2 commit 1d7ddc8
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 18 deletions.
58 changes: 55 additions & 3 deletions ndonnx/_logic_in_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,59 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from .array import Array
from .array import Array, asarray
from .dtypes import (
int8,
int16,
int32,
int64,
uint8,
uint16,
uint32,
uint64,
float16,
float32,
float64,
bool_ as bool,
DType,
)
from .funcs import (
arange,
ones,
finfo,
iinfo,
zeros,
reshape,
all,
isfinite,
isnan,
equal,
)


__all__ = ["Array"]
__all__ = [
"Array",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float16",
"float32",
"float64",
"bool",
"DType",
"arange",
"ones",
"finfo",
"iinfo",
"zeros",
"reshape",
"asarray",
"all",
"isfinite",
"isnan",
"equal",
]
45 changes: 43 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,19 @@ def __getitem__(self, index: Index) -> Self:
ends=op.const([index + 1]),
axes=op.const([0]),
)
var = op.squeeze(var, axes=op.const(0))
var = op.squeeze(var, axes=op.const([0]))
return type(self)(var)
if isinstance(index, tuple) and all_items_are_int(index):
starts, ends, axes = zip(*[(el, el + 1, i) for i, el in enumerate(index)])
var = op.slice(
self.var,
starts=op.const(starts),
ends=op.const(ends),
axes=op.const(axes),
)
var = op.squeeze(var, axes=op.const(axes))
return type(self)(var)

raise NotImplementedError

@property
Expand All @@ -67,6 +78,26 @@ def to_numpy(self) -> np.ndarray:
return np_arr
raise ValueError("no propagated value available")

def all(
self, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False
) -> TyArrayBool:
if isinstance(axis, int):
axis = (axis,)

bools = self.astype(dtypes.bool_)

if bools.ndim == 0:
if axis:
ValueError("'axis' were provided but 'self' is a scalar")
# Nothing left to reduce
return safe_cast(TyArrayBool, bools)

axes = op.const(list(axis)) if axis else None

# max int8 is returned if dimensions are empty
var = op.reduce_min(bools.astype(dtypes.int8).var, axes=axes, keepdims=keepdims)
return safe_cast(TyArrayBool, TyArrayInt8(var).astype(dtypes.bool_))

def disassemble(self) -> tuple[Components, Schema]:
dtype_info = self.dtype._info
component_schema = var_to_primitive(self.var)
Expand All @@ -75,7 +106,11 @@ def disassemble(self) -> tuple[Components, Schema]:
return components, schema

def reshape(self, shape: tuple[int, ...]) -> Self:
var = op.reshape(self.var, op.const(shape))
var = op.reshape(self.var, op.const(shape), allowzero=True)
return type(self)(var)

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
var = op.expand(self.var, op.const(shape, dtype=np.int64))
return type(self)(var)

def as_core_dtype(self, dtype: CoreDTypes) -> TyArray:
Expand Down Expand Up @@ -280,6 +315,12 @@ def is_sequence_of_core_data(
return all(isinstance(d, TyArray) for d in seq)


def all_items_are_int(
seq: Sequence,
) -> TypeGuard[Sequence[int]]:
return all(isinstance(d, int) for d in seq)


def _promote_and_apply_op(
this: TyArray,
other: TyArrayBase,
Expand Down
5 changes: 5 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def reshape(self, shape: tuple[int, ...]) -> Self:

return type(self)(is_nat=is_nat, data=data, unit=self.dtype.unit)

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
data = self.data.broadcast_to(shape)
is_nat = self.is_nat.broadcast_to(shape)
return type(self)(data=data, is_nat=is_nat, unit=self.dtype.unit)

def _eqcomp(self, other) -> TyArrayBase:
raise NotImplementedError

Expand Down
5 changes: 5 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def reshape(self, shape: tuple[int, ...]) -> Self:
mask = self.mask.reshape(shape) if self.mask is not None else None
return type(self)(data=data, mask=mask)

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
data = self.data.broadcast_to(shape)
mask = self.mask.broadcast_to(shape) if self.mask else None
return type(self)(data=data, mask=mask)

def __getitem__(self, index: Index) -> Self:
new_data = self.data[index]
new_mask = self.mask[index] if self.mask is not None else None
Expand Down
3 changes: 3 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/py_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def shape(self) -> OnnxShape:
def reshape(self, shape: tuple[int, ...]) -> Self:
raise ValueError("cannot reshape Python scalar")

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
raise ValueError("cannot broadcast Python scalar")

def __add__(self, rhs: TyArrayBase) -> TyArrayBase:
return _promote_and_apply_op(self, rhs, operator.add, True)

Expand Down
12 changes: 12 additions & 0 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ def _astype(self, dtype: DType) -> TyArrayBase | NotImplementedType:
"""
return NotImplemented

def all(self) -> TyArrayBase:
raise ValueError(f"'all' is not implemented for `{self.dtype}`")

@abstractmethod
def broadcast_to(self, shape: tuple[int, ...]) -> Self: ...

def isnan(self) -> TyArrayBase:
raise ValueError(f"'isnan' is not implemented for {self.dtype}")

def isfinite(self) -> TyArrayBase:
raise ValueError(f"'isinfinite' is not implemented for {self.dtype}")

def _where(
self, cond: TyArrayBool, y: TyArrayBase
) -> TyArrayBase | NotImplementedType:
Expand Down
26 changes: 19 additions & 7 deletions ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .dtypes import DType

StrictShape = tuple[int, ...]
StandardShape = tuple[int | None, ...]
StandardShape = int | tuple[int | None, ...]
OnnxShape = tuple[int | str | None, ...]

ScalarIndex = int | bool | slice | EllipsisType | None
Expand Down Expand Up @@ -126,16 +126,18 @@ def __array_namespace__(
raise NotImplementedError

def __bool__(self: Array, /) -> bool:
raise NotImplementedError
# TODO: If we want to do this, it should be on the typed array!
return bool(self.unwrap_numpy())

def __complex__(self: Array, /) -> complex:
raise NotImplementedError

def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array: # type: ignore
raise NotImplementedError
return _apply_op(self, other, std_ops.eq)

def __float__(self: Array, /) -> float:
raise NotImplementedError
# TODO: If we want to do this, it should be on the typed array!
return float(self.unwrap_numpy())

def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
raise NotImplementedError
Expand All @@ -150,7 +152,8 @@ def __index__(self: Array, /) -> int:
raise NotImplementedError

def __int__(self: Array, /) -> int:
raise NotImplementedError
# TODO: If we want to do this, it should be on the typed array!
return int(self.unwrap_numpy())

def __invert__(self: Array, /) -> Array:
raise NotImplementedError
Expand Down Expand Up @@ -218,10 +221,19 @@ def __xor__(self: Array, other: Union[int, bool, Array], /) -> Array:
raise NotImplementedError


def asarray(obj: int | float | bool | str | Array | np.ndarray) -> Array:
def asarray(
obj: Array | bool | int | float | np.ndarray,
/,
*,
dtype: DType | None = None,
device=None,
copy: bool | None = None,
) -> Array:
if isinstance(obj, Array):
return obj
data = ascoredata(op.const(obj))
data: TyArrayBase = ascoredata(op.const(obj))
if dtype:
data = data.astype(dtype)
return Array._from_data(data)


Expand Down
Loading

0 comments on commit 1d7ddc8

Please sign in to comment.