From 60aa96989961661955597b485850bcc912146e8f Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Sun, 1 Sep 2024 16:49:11 +0200 Subject: [PATCH] Update _utils.py --- xarray/namedarray/_array_api/_utils.py | 84 +++++++++++++++----------- 1 file changed, 50 insertions(+), 34 deletions(-) diff --git a/xarray/namedarray/_array_api/_utils.py b/xarray/namedarray/_array_api/_utils.py index 9678640596f..c6975960024 100644 --- a/xarray/namedarray/_array_api/_utils.py +++ b/xarray/namedarray/_array_api/_utils.py @@ -1,6 +1,8 @@ from __future__ import annotations from collections.abc import Iterable +from itertools import zip_longest +import math from types import ModuleType from typing import TYPE_CHECKING, Any @@ -68,8 +70,21 @@ def _infer_dims( shape: _Shape, dims: _DimsLike | Default = _default, ) -> _DimsLike: + """ + Create default dim names if no dims were supplied. + + Examples + -------- + >>> _infer_dims(()) + () + >>> _infer_dims((1,)) + ('dim_0',) + >>> _infer_dims((3, 1)) + ('dim_1', 'dim_0') + """ if dims is _default: - return tuple(f"dim_{n}" for n in range(len(shape))) + ndim = len(shape) + return tuple(f"dim_{ndim - 1 - n}" for n in range(ndim)) else: return dims @@ -199,6 +214,11 @@ def _raise_if_any_duplicate_dimensions( ) +def _isnone(shape: _Shape) -> tuple[bool, ...]: + # TODO: math.isnan should not be needed for array api, but dask still uses np.nan: + return tuple(v is None and math.isnan(v) for v in shape) + + def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: """ Get the expected broadcasted dims. @@ -209,6 +229,8 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: >>> b = NamedArray(("y", "z"), np.zeros((3, 4))) >>> _get_broadcasted_dims(a, b) (('x', 'y', 'z'), (5, 3, 4)) + >>> _get_broadcasted_dims(b, a) + (('x', 'y', 'z'), (5, 3, 4)) >>> a = NamedArray(("x", "y", "z"), np.zeros((5, 3, 4))) >>> b = NamedArray(("x", "y", "z"), np.zeros((0, 3, 4))) @@ -230,37 +252,31 @@ def _get_broadcasted_dims(*arrays: NamedArray) -> tuple[_Dims, _Shape]: >>> _get_broadcasted_dims(a, b) Traceback (most recent call last): ... - ValueError: operands cannot be broadcast together with mismatched lengths for dimension 'x': (5, 2) + ValueError: operands could not be broadcast together with dims = (('x', 'y', 'z'), ('x', 'y', 'z')) and shapes = ((5, 3, 4), (2, 3, 4)) """ - - def broadcastable(e1: int, e2: int) -> bool: - # out = e1 > 1 and e2 <= 1 - # out |= e2 > 1 and e1 <= 1 - - # out = e1 >= 0 and e2 <= 1 - # out |= e2 >= 0 and e1 <= 1 - - out = e1 <= 1 or e2 <= 1 - - return out - - # validate dimensions - all_dims = {} - for x in arrays: - _dims = x.dims - _raise_if_any_duplicate_dimensions(_dims, err_context="Broadcasting") - - for d, s in zip(_dims, x.shape): - if d not in all_dims: - all_dims[d] = s - elif all_dims[d] != s: - if broadcastable(all_dims[d], s): - max(all_dims[d], s) - else: - raise ValueError( - "operands cannot be broadcast together " - f"with mismatched lengths for dimension {d!r}: {(all_dims[d], s)}" - ) - - # TODO: Return flag whether broadcasting is needed? - return tuple(all_dims.keys()), tuple(all_dims.values()) + dims = tuple(a.dims for a in arrays) + shapes = tuple(a.shape for a in arrays) + + if len(shapes) == 1: + return shapes[0] + + out_dims = [] + out_shape = [] + for d, sizes in zip( + zip_longest(*map(reversed, dims), fillvalue=_default), + zip_longest(*map(reversed, shapes), fillvalue=-1), + ): + _d = dict.fromkeys(d) + _d.pop(_default, None) + _d = list(_d) + + dim = None if any(_isnone(sizes)) else max(sizes) + + if any(i not in [-1, 0, 1, dim] for i in sizes) or len(_d) != 1: + raise ValueError( + f"operands could not be broadcast together with {dims = } and {shapes = }" + ) + + out_dims.append(_d[0]) + out_shape.append(dim) + return tuple(reversed(out_dims)), tuple(reversed(out_shape))