Skip to content

Commit

Permalink
__future__ annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Aug 29, 2024
1 parent 49a21de commit 7254338
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
7 changes: 4 additions & 3 deletions ndonnx/_core/_coreimpl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING

Expand All @@ -23,9 +24,9 @@ class CoreOperationsImpl(OperationsBlock):
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: "Dtype",
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> "Array":
) -> Array:
if not isinstance(dtype, dtypes.CoreType):
return NotImplemented
return ndx.Array._from_fields(
Expand All @@ -38,7 +39,7 @@ def make_array(
)

@validate_core
def make_nullable(self, x: "Array", null: "Array") -> "Array":
def make_nullable(self, x: Array, null: Array) -> Array:
if null.dtype != ndx.bool:
raise TypeError("'null' must be a boolean array")

Expand Down
19 changes: 9 additions & 10 deletions ndonnx/_core/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

if TYPE_CHECKING:
from ndonnx._array import IndexType
from ndonnx._data_types import Dtype


class OperationsBlock:
Expand Down Expand Up @@ -256,7 +257,7 @@ def cumulative_sum(
x,
*,
axis: int | None = None,
dtype: ndx.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
include_initial: bool = False,
):
return NotImplemented
Expand All @@ -275,7 +276,7 @@ def prod(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -298,7 +299,7 @@ def sum(
x,
*,
axis=None,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
keepdims: bool = False,
) -> ndx.Array:
return NotImplemented
Expand All @@ -310,7 +311,7 @@ def var(
axis=None,
keepdims: bool = False,
correction=0.0,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
) -> ndx.Array:
return NotImplemented

Expand Down Expand Up @@ -357,7 +358,7 @@ def full_like(self, x, fill_value, dtype=None, device=None) -> ndx.Array:
def ones(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented
Expand All @@ -370,14 +371,12 @@ def ones_like(
def zeros(
self,
shape,
dtype: dtypes.CoreType | dtypes.StructType | None = None,
dtype: Dtype | None = None,
device=None,
):
return NotImplemented

def zeros_like(
self, x, dtype: dtypes.CoreType | dtypes.StructType | None = None, device=None
):
def zeros_like(self, x, dtype: Dtype | None = None, device=None):
return NotImplemented

def empty(self, shape, dtype=None, device=None) -> ndx.Array:
Expand Down Expand Up @@ -422,7 +421,7 @@ def static_shape(self, x) -> tuple[int | None, ...]:
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: dtypes.CoreType | dtypes.StructType,
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> ndx.Array:
return NotImplemented
Expand Down
6 changes: 4 additions & 2 deletions ndonnx/_core/_nullableimpl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

from typing import TYPE_CHECKING, Union

import ndonnx as ndx
Expand All @@ -16,12 +18,12 @@

class NullableOperationsImpl(OperationsBlock):
@validate_core
def fill_null(self, x: "Array", value) -> "Array":
def fill_null(self, x: Array, value) -> Array:
value = ndx.asarray(value)
if value.dtype != x.values.dtype:
value = value.astype(x.values.dtype)
return ndx.where(x.null, value, x.values)

@validate_core
def make_nullable(self, x: "Array", null: "Array") -> "Array":
def make_nullable(self, x: Array, null: Array) -> Array:
return NotImplemented
10 changes: 6 additions & 4 deletions ndonnx/_core/_shapeimpl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -257,9 +259,9 @@ def ones_like(self, x, dtype=None, device=None):
def make_array(
self,
shape: tuple[int | None | str, ...],
dtype: "Dtype",
dtype: Dtype,
eager_value: np.ndarray | None = None,
) -> "Array":
) -> Array:
if isinstance(dtype, dtypes.CoreType):
return NotImplemented

Expand All @@ -281,7 +283,7 @@ def make_array(
**fields,
)

def getitem(self, x: "Array", index: "IndexType") -> "Array":
def getitem(self, x: Array, index: IndexType) -> Array:
if isinstance(index, ndx.Array) and not (
isinstance(index.dtype, dtypes.Integral) or index.dtype == dtypes.bool
):
Expand All @@ -295,7 +297,7 @@ def getitem(self, x: "Array", index: "IndexType") -> "Array":
return x._transmute(lambda corearray: corearray[index])


def _assemble_output_recurse(dtype: "Dtype", values: dict) -> np.ndarray:
def _assemble_output_recurse(dtype: Dtype, values: dict) -> np.ndarray:
if isinstance(dtype, dtypes.CoreType):
return dtype._assemble_output(values)
else:
Expand Down

0 comments on commit 7254338

Please sign in to comment.