diff --git a/einx/op/index.py b/einx/op/index.py index d5a2fa7..92e748e 100644 --- a/einx/op/index.py +++ b/einx/op/index.py @@ -35,13 +35,17 @@ def index_stage3(exprs_in, tensors_in, expr_out, op=None, backend=None): for expr in root.all(): if isinstance(expr, einx.expr.stage3.Concatenation): raise ValueError("Concatenation not allowed") + if not with_update: + for expr in expr_out.all(): + if einx.expr.stage3.is_marked(expr): + raise ValueError("Brackets in the output expression are not allowed") exprs_in = list(exprs_in) marked_coordinate_axes = [expr for expr in exprs_in[1].all() if isinstance(expr, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(expr)] if len(marked_coordinate_axes) > 1: - raise ValueError(f"Expected at most one coordinate axis, got {len(marked_coordinate_axes)}") + raise ValueError(f"Expected at most one coordinate axis in the second expression, got {len(marked_coordinate_axes)}") ndim = marked_coordinate_axes[0].value if len(marked_coordinate_axes) == 1 else 1 - coordinate_axis_name = marked_coordinate_axes[0].name if len(marked_coordinate_axes) == 1 else None + coordinate_axis_name = marked_coordinate_axes[0].name if len(marked_coordinate_axes) == 1 and (not marked_coordinate_axes[0].is_unnamed or marked_coordinate_axes[0].value != 1) else None marked_tensor_axis_names = set(expr.name for expr in exprs_in[0].all() if isinstance(expr, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(expr)) if len(marked_tensor_axis_names) != ndim: diff --git a/test/test_shapes.py b/test/test_shapes.py index 7e469a1..ada05cf 100644 --- a/test/test_shapes.py +++ b/test/test_shapes.py @@ -373,6 +373,14 @@ def test_shape_index(backend): with pytest.raises(Exception): einx.get_at("b ([1 1]) c, b p [2] -> b p c", x, y) + x = backend.zeros((4, 5, 6)) + y = backend.cast(backend.zeros((4, 5)), coord_dtype) + assert einx.get_at("b t [d], b t -> b t", x, y).shape == (4, 5) + assert einx.get_at("... [d], ... -> ...", x, y).shape == (4, 5) + assert einx.get_at("b t [d], b (t [1]) -> b (t 1)", x, y).shape == (4, 5) + with pytest.raises(ValueError): + einx.get_at("b t [d], b (t [1]) -> b (t [1])", x, y) + @pytest.mark.parametrize("backend", backends) def test_shape_vmap_with_axis(backend): x = backend.ones((10, 10), "float32") diff --git a/test/test_values.py b/test/test_values.py index 83e2f1d..24ff80a 100644 --- a/test/test_values.py +++ b/test/test_values.py @@ -112,6 +112,14 @@ def test_values(backend): backend.to_tensor(np.stack(np.meshgrid(np.arange(6), np.arange(5), indexing="xy"), axis=-1).astype("int32")), ) + coord_dtype = "int32" if backend.name != "torch" else "long" + x = backend.to_tensor(rng.uniform(size=(4, 5, 6)).astype("float32")) + y = backend.cast(backend.ones((4, 5)) * 3, coord_dtype) + assert backend.allclose( + einx.get_at("... [d], ... -> ...", x, y), + x[:, :, 3], + ) + @pytest.mark.parametrize("backend", backends) def test_compare_backends(backend): x = np.random.uniform(size=(10, 3, 10)).astype("float32")