Skip to content

Commit

Permalink
Implement various constructor functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 8, 2024
1 parent 1d7ddc8 commit 342ee63
Show file tree
Hide file tree
Showing 9 changed files with 249 additions and 80 deletions.
31 changes: 24 additions & 7 deletions ndonnx/_logic_in_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,23 @@
)
from .funcs import (
arange,
empty,
empty_like,
ones,
ones_like,
full,
full_like,
finfo,
iinfo,
zeros,
reshape,
all,
isfinite,
isnan,
equal,
zeros,
zeros_like,
linspace,
where,
)

__all__ = [
Expand All @@ -45,15 +53,24 @@
"float64",
"bool",
"DType",
"all",
"arange",
"ones",
"asarray",
"empty",
"empty_like",
"equal",
"finfo",
"full",
"full_like",
"iinfo",
"zeros",
"reshape",
"asarray",
"all",
"isfinite",
"isnan",
"equal",
"linspace",
"ones",
"ones_like",
"reshape",
"where",
"zeros",
"zeros",
"zeros_like",
]
97 changes: 74 additions & 23 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,57 @@
)


class _Index:
starts: list[int]
ends: list[int]
steps: list[int]
axes: list[int]
squeeze_axes: list[int]

def __init__(self, index: Index):
if isinstance(index, tuple):
index_ = index
elif isinstance(index, int | slice):
index_ = (index,)
else:
raise NotImplementedError
self.starts = []
self.ends = []
self.steps = []
self.axes = []
self.squeeze_axes = []

def compute_end_slice(stop: int | None, step: int | None) -> int:
if isinstance(stop, int):
return stop
step = step or 1
# Iterate "to the end"
if step < 1:
return int(np.iinfo(np.int64).min)
return int(np.iinfo(np.int64).max)

def compute_end_single_idx(start: int):
end = start + 1
if end == 0:
return np.iinfo(np.int64).max
return end

for i, el in enumerate(index_):
if isinstance(el, slice):
self.starts.append(el.start or 0)
self.ends.append(compute_end_slice(el.stop, el.step))
self.axes.append(i)
self.steps.append(el.step or 1)
elif isinstance(el, int):
self.starts.append(el)
self.ends.append(compute_end_single_idx(el))
self.axes.append(i)
self.steps.append(1)
self.squeeze_axes.append(i)
else:
raise NotImplementedError


class TyArray(TyArrayBase):
dtype: dtypes.CoreDTypes
var: Var
Expand All @@ -40,27 +91,18 @@ def __init__(self, var: Var):
self.var = var

def __getitem__(self, index: Index) -> Self:
if isinstance(index, int):
var = op.slice(
self.var,
starts=op.const([index]),
ends=op.const([index + 1]),
axes=op.const([0]),
)
var = op.squeeze(var, axes=op.const([0]))
return type(self)(var)
if isinstance(index, tuple) and all_items_are_int(index):
starts, ends, axes = zip(*[(el, el + 1, i) for i, el in enumerate(index)])
var = op.slice(
self.var,
starts=op.const(starts),
ends=op.const(ends),
axes=op.const(axes),
)
var = op.squeeze(var, axes=op.const(axes))
return type(self)(var)

raise NotImplementedError
if index == ():
return self
parsed = _Index(index)
var = op.slice(
self.var,
starts=op.const(parsed.starts),
ends=op.const(parsed.ends),
axes=op.const(parsed.axes),
steps=op.const(parsed.steps),
)
var = op.squeeze(var, axes=op.const(parsed.squeeze_axes, np.int64))
return type(self)(var)

@property
def shape(self) -> OnnxShape:
Expand All @@ -69,6 +111,11 @@ def shape(self) -> OnnxShape:
raise ValueError("Missing shape information")
return shape

@property
def dynamic_shape(self) -> TyArrayInt64:
var = op.shape(self.var)
return TyArrayInt64(var)

def to_numpy(self) -> np.ndarray:
if self.var._value is not None:
np_arr = np.asarray(self.var._value.value)
Expand Down Expand Up @@ -109,8 +156,12 @@ def reshape(self, shape: tuple[int, ...]) -> Self:
var = op.reshape(self.var, op.const(shape), allowzero=True)
return type(self)(var)

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
var = op.expand(self.var, op.const(shape, dtype=np.int64))
def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self:
if isinstance(shape, tuple):
shape_var = op.const(shape, dtype=np.int64)
else:
shape_var = shape.var
var = op.expand(self.var, shape_var)
return type(self)(var)

def as_core_dtype(self, dtype: CoreDTypes) -> TyArray:
Expand Down
6 changes: 5 additions & 1 deletion ndonnx/_logic_in_data/_typed_array/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ def __getitem__(self, index: Index) -> Self:
def shape(self) -> OnnxShape:
return self.data.shape

@property
def dynamic_shape(self) -> TyArrayInt64:
return self.data.dynamic_shape

def reshape(self, shape: tuple[int, ...]) -> Self:
is_nat = self.is_nat.reshape(shape)
data = self.data.reshape(shape)

return type(self)(is_nat=is_nat, data=data, unit=self.dtype.unit)

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self:
data = self.data.broadcast_to(shape)
is_nat = self.is_nat.broadcast_to(shape)
return type(self)(data=data, is_nat=is_nat, unit=self.dtype.unit)
Expand Down
13 changes: 7 additions & 6 deletions ndonnx/_logic_in_data/_typed_array/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .. import dtypes
from ..dtypes import CoreDTypes, DType, NCoreDTypes
from ..schema import Schema, flatten_components
from .core import TyArray, TyArrayBool
from .core import TyArray, TyArrayBool, TyArrayInt64
from .typed_array import TyArrayBase

if TYPE_CHECKING:
Expand Down Expand Up @@ -82,17 +82,18 @@ def __init__(self, data: TyArray, mask: TyArrayBool | None):

@property
def shape(self) -> OnnxShape:
shape = self.data.shape
if shape is None:
raise ValueError("Missing shape information")
return shape
return self.data.shape

@property
def dynamic_shape(self) -> TyArrayInt64:
return self.data.dynamic_shape

def reshape(self, shape: tuple[int, ...]) -> Self:
data = self.data.reshape(shape)
mask = self.mask.reshape(shape) if self.mask is not None else None
return type(self)(data=data, mask=mask)

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self:
data = self.data.broadcast_to(shape)
mask = self.mask.broadcast_to(shape) if self.mask else None
return type(self)(data=data, mask=mask)
Expand Down
8 changes: 6 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/py_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
if TYPE_CHECKING:
from ..array import OnnxShape
from ..schema import Components, Schema
from .core import TyArrayBool
from .core import TyArrayBool, TyArrayInt64


class _ArrayPyScalar(TyArrayBase):
Expand Down Expand Up @@ -54,10 +54,14 @@ def ndim(self) -> int:
def shape(self) -> OnnxShape:
return ()

@property
def dynamic_shape(self) -> TyArrayInt64:
raise ValueError("'dynamic_shape' should never be called on Python scalar")

def reshape(self, shape: tuple[int, ...]) -> Self:
raise ValueError("cannot reshape Python scalar")

def broadcast_to(self, shape: tuple[int, ...]) -> Self:
def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self:
raise ValueError("cannot broadcast Python scalar")

def __add__(self, rhs: TyArrayBase) -> TyArrayBase:
Expand Down
8 changes: 6 additions & 2 deletions ndonnx/_logic_in_data/_typed_array/typed_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
if TYPE_CHECKING:
from ..array import Index, OnnxShape
from ..schema import Components, Schema
from .core import TyArray, TyArrayBool
from .core import TyArray, TyArrayBool, TyArrayInt64


class TyArrayBase(ABC):
Expand All @@ -34,6 +34,10 @@ def __getitem__(self, index: Index) -> Self: ...
@abstractmethod
def shape(self) -> OnnxShape: ...

@property
@abstractmethod
def dynamic_shape(self) -> TyArrayInt64: ...

@abstractmethod
def reshape(self, shape: tuple[int, ...]) -> Self: ...

Expand Down Expand Up @@ -81,7 +85,7 @@ def all(self) -> TyArrayBase:
raise ValueError(f"'all' is not implemented for `{self.dtype}`")

@abstractmethod
def broadcast_to(self, shape: tuple[int, ...]) -> Self: ...
def broadcast_to(self, shape: tuple[int, ...] | TyArrayInt64) -> Self: ...

def isnan(self) -> TyArrayBase:
raise ValueError(f"'isnan' is not implemented for {self.dtype}")
Expand Down
18 changes: 10 additions & 8 deletions ndonnx/_logic_in_data/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,19 @@ def _from_data(cls, data: TyArrayBase) -> Array:
return inst

@property
def shape(self) -> StandardShape:
def shape(self) -> tuple[int | None, ...]:
shape = self._data.shape
return tuple(None if isinstance(item, str) else item for item in shape)

@property
def ndim(self) -> int:
return len(self.shape)

@property
def dynamic_shape(self) -> Array:
shape = self._data.dynamic_shape
return Array._from_data(shape)

@property
def dtype(self) -> DType:
return self._data.dtype
Expand Down Expand Up @@ -237,13 +246,6 @@ def asarray(
return Array._from_data(data)


def where(cond: Array, a: Array, b: Array) -> Array:
from ._typed_array.funcs import typed_where

data = typed_where(cond._data, a._data, b._data)
return Array._from_data(data)


def add(a: Array, b: Array) -> Array:
return a + b

Expand Down
Loading

0 comments on commit 342ee63

Please sign in to comment.