Skip to content

Commit

Permalink
Fix truediv for time deltas
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Jan 24, 2025
1 parent f2ffc1c commit dfd3073
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
20 changes: 16 additions & 4 deletions ndonnx/_refactor/_typed_array/date_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,23 @@ def __sub__(self, rhs: TyArrayBase | _PyScalar) -> TyArrayTimeDelta:
def __rsub__(self, lhs: TyArrayBase | _PyScalar) -> TyArrayTimeDelta:
return _apply_op(self, lhs, operator.sub, False)

def __truediv__(self, rhs: TyArrayBase | _PyScalar) -> TyArrayTimeDelta:
return _apply_op(self, rhs, operator.truediv, True)
def __truediv__(self, rhs: TyArrayBase | _PyScalar) -> TyArrayBase:
if isinstance(rhs, onnx.TyArrayNumber | float | int):
data = (self.data / astyarray(rhs)).astype(onnx.int64)
return TyArrayTimeDelta(is_nat=self.is_nat, data=data, unit=self.dtype.unit)
if isinstance(rhs, TyArrayTimeDelta) and self.dtype == rhs.dtype:
res = (self.data / rhs.data).astype(onnx.float64)
res[safe_cast(onnx.TyArrayBool, self.is_nat | rhs.is_nat)] = astyarray(
np.nan, dtype=onnx.float64
)
return res
return NotImplemented

def __rtruediv__(self, lhs: TyArrayBase | _PyScalar) -> TyArrayTimeDelta:
return _apply_op(self, lhs, operator.truediv, False)
def __rtruediv__(self, lhs: TyArrayBase | _PyScalar) -> TyArrayBase:
if isinstance(lhs, onnx.TyArrayNumber | float | int):
data = (astyarray(lhs) / self.data).astype(onnx.int64)
return TyArrayTimeDelta(is_nat=self.is_nat, data=data, unit=self.dtype.unit)
return NotImplemented

def _eqcomp(self, other) -> onnx.TyArrayBool:
if not isinstance(other.dtype, TimeDelta):
Expand Down
5 changes: 4 additions & 1 deletion ndonnx/_refactor/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -171,6 +171,9 @@ def make_nullable(
TypeError
If the data type of ``x`` does not have a nullable counterpart.
"""
x = x.copy()
null = None if null is None else null.copy()

if null is None:
if isinstance(x, TyMaArray):
return x
Expand Down

0 comments on commit dfd3073

Please sign in to comment.