Skip to content

Commit

Permalink
Keep things DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
cbourjau committed Sep 5, 2024
1 parent 77c85b2 commit 66ebb5e
Showing 1 changed file with 6 additions and 49 deletions.
55 changes: 6 additions & 49 deletions ndonnx/_logic_in_data/_typed_array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,47 +108,18 @@ def _where(

class TyArrayNumber(TyArray[CORE_DTYPES]):
def __add__(self, rhs: TyArrayBase) -> TyArrayBase:
if isinstance(rhs, TyArrayNumber):
# NOTE: Can't always promote for all data types (c.f. datetime / timedelta)
if type(self) != type(rhs):
a, b = promote(self, rhs)
return a + b
var = op.add(self.var, rhs.var)
return ascoredata(var)
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.add, op.add)

def __mul__(self, rhs: TyArrayBase) -> TyArrayBase:
if isinstance(rhs, TyArrayNumber):
# NOTE: Can't always promote for all data types (c.f. datetime / timedelta)
if type(self) != type(rhs):
a, b = promote(self, rhs)
return a * b
var = op.mul(self.var, rhs.var)
return ascoredata(var)
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.mul, op.mul)

def __sub__(self, rhs: TyArrayBase) -> TyArrayBase:
if isinstance(rhs, TyArrayNumber):
# NOTE: Can't always promote for all data types (c.f. datetime / timedelta)
if type(self) != type(rhs):
a, b = promote(self, rhs)
return a - b
var = op.add(self.var, rhs.var)
return ascoredata(var)
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.sub, op.sub)


class TyArrayInteger(TyArrayNumber[CORE_DTYPES]):
def __or__(self, rhs: TyArrayBase) -> TyArrayBase:
if isinstance(rhs, TyArray):
if self.dtype != rhs.dtype:
a, b = promote(self, rhs)
return a | b

# Data is core & integer
var = op.bitwise_or(self.var, rhs.var)
return ascoredata(var)
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.or_, op.bitwise_or)


class TyArrayFloating(TyArrayNumber[CORE_DTYPES]): ...
Expand All @@ -158,24 +129,10 @@ class TyArrayBool(TyArray[dtypes.Bool]):
dtype = dtypes.bool_

def __or__(self, rhs: TyArrayBase) -> TyArrayBase:
if self.dtype != rhs.dtype:
a, b = promote(self, rhs)
return a | b

if isinstance(rhs, TyArrayBool):
var = op.or_(self.var, rhs.var)
return ascoredata(var)
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.or_, op.or_)

def __and__(self, rhs: TyArrayBase) -> TyArrayBase:
if self.dtype != rhs.dtype:
a, b = promote(self, rhs)
return a & b

if isinstance(rhs, TyArrayBool):
var = op.and_(self.var, rhs.var)
return ascoredata(var)
return NotImplemented
return _promote_and_apply_op(self, rhs, operator.and_, op.and_)

def __invert__(self) -> TyArrayBool:
var = op.not_(self.var)
Expand Down

0 comments on commit 66ebb5e

Please sign in to comment.