From 2e6f3da8ca0397c5a4afd9f503215086396439e7 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Wed, 21 Aug 2024 22:57:36 +0200 Subject: [PATCH] handle scalars --- xarray/namedarray/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index af55abefdca..33555d40275 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -668,10 +668,10 @@ def __pos__(self, /): return positive(self) - def __add__(self, other, /): - from xarray.namedarray._array_api import add + def __add__(self, other: int | float | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import add, asarray - return add(self, other) + return add(self, asarray(other)) def __sub__(self, other, /): from xarray.namedarray._array_api import subtract @@ -743,11 +743,10 @@ def __rshift__(self, other, /): return bitwise_right_shift(self) # Comparison Operators + def __eq__(self, other: int | float | bool | NamedArray, /) -> NamedArray: + from xarray.namedarray._array_api import equal, asarray - def __eq__(self, other, /): - from xarray.namedarray._array_api import equal - - return equal(self, other) + return equal(self, asarray(other)) def __ge__(self, other, /): from xarray.namedarray._array_api import greater_equal