From e9de1aec8b1f86da20c128f4021a54351ccadcf6 Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 20 Jun 2024 11:25:16 +0100 Subject: [PATCH] Ensure integer arrays used for indexing are NumPy arrays (#485) --- cubed/core/ops.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index c83a265c..b62de566 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -15,6 +15,7 @@ from toolz import accumulate, map from cubed import config +from cubed.backend_array_api import backend_array_to_numpy_array from cubed.backend_array_api import namespace as nxp from cubed.backend_array_api import numpy_array_to_backend_array from cubed.core.array import CoreArray, check_array_specs, compute, gensym @@ -408,8 +409,11 @@ def index(x, key): key = (key,) # Replace Cubed arrays with NumPy arrays - note that this may trigger a computation! + # Note that NumPy arrays are needed for ndindex. key = tuple( - dim_sel.compute() if isinstance(dim_sel, CoreArray) else dim_sel + backend_array_to_numpy_array(dim_sel.compute()) + if isinstance(dim_sel, CoreArray) + else dim_sel for dim_sel in key )