Skip to content

Commit

Permalink
Fix bug when using "1" as coordinate axis in einx.index
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Jan 18, 2024
1 parent 88006a7 commit 4273732
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
8 changes: 6 additions & 2 deletions einx/op/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ def index_stage3(exprs_in, tensors_in, expr_out, op=None, backend=None):
for expr in root.all():
if isinstance(expr, einx.expr.stage3.Concatenation):
raise ValueError("Concatenation not allowed")
if not with_update:
for expr in expr_out.all():
if einx.expr.stage3.is_marked(expr):
raise ValueError("Brackets in the output expression are not allowed")
exprs_in = list(exprs_in)

marked_coordinate_axes = [expr for expr in exprs_in[1].all() if isinstance(expr, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(expr)]
if len(marked_coordinate_axes) > 1:
raise ValueError(f"Expected at most one coordinate axis, got {len(marked_coordinate_axes)}")
raise ValueError(f"Expected at most one coordinate axis in the second expression, got {len(marked_coordinate_axes)}")
ndim = marked_coordinate_axes[0].value if len(marked_coordinate_axes) == 1 else 1
coordinate_axis_name = marked_coordinate_axes[0].name if len(marked_coordinate_axes) == 1 else None
coordinate_axis_name = marked_coordinate_axes[0].name if len(marked_coordinate_axes) == 1 and (not marked_coordinate_axes[0].is_unnamed or marked_coordinate_axes[0].value != 1) else None

marked_tensor_axis_names = set(expr.name for expr in exprs_in[0].all() if isinstance(expr, einx.expr.stage3.Axis) and einx.expr.stage3.is_marked(expr))
if len(marked_tensor_axis_names) != ndim:
Expand Down
8 changes: 8 additions & 0 deletions test/test_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,14 @@ def test_shape_index(backend):
with pytest.raises(Exception):
einx.get_at("b ([1 1]) c, b p [2] -> b p c", x, y)

x = backend.zeros((4, 5, 6))
y = backend.cast(backend.zeros((4, 5)), coord_dtype)
assert einx.get_at("b t [d], b t -> b t", x, y).shape == (4, 5)
assert einx.get_at("... [d], ... -> ...", x, y).shape == (4, 5)
assert einx.get_at("b t [d], b (t [1]) -> b (t 1)", x, y).shape == (4, 5)
with pytest.raises(ValueError):
einx.get_at("b t [d], b (t [1]) -> b (t [1])", x, y)

@pytest.mark.parametrize("backend", backends)
def test_shape_vmap_with_axis(backend):
x = backend.ones((10, 10), "float32")
Expand Down
8 changes: 8 additions & 0 deletions test/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ def test_values(backend):
backend.to_tensor(np.stack(np.meshgrid(np.arange(6), np.arange(5), indexing="xy"), axis=-1).astype("int32")),
)

coord_dtype = "int32" if backend.name != "torch" else "long"
x = backend.to_tensor(rng.uniform(size=(4, 5, 6)).astype("float32"))
y = backend.cast(backend.ones((4, 5)) * 3, coord_dtype)
assert backend.allclose(
einx.get_at("... [d], ... -> ...", x, y),
x[:, :, 3],
)

@pytest.mark.parametrize("backend", backends)
def test_compare_backends(backend):
x = np.random.uniform(size=(10, 3, 10)).astype("float32")
Expand Down

0 comments on commit 4273732

Please sign in to comment.