From 80673c701661dfff979f59911e62952aaef239b7 Mon Sep 17 00:00:00 2001 From: Florian Fervers Date: Sat, 9 Mar 2024 22:10:07 +0100 Subject: [PATCH] Better error reporting --- einx/op/index.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/einx/op/index.py b/einx/op/index.py index 3c7b604..4255a36 100644 --- a/einx/op/index.py +++ b/einx/op/index.py @@ -354,6 +354,24 @@ def after_stage2(exprs1, exprs2): )[: len(op[0]) + 1] exprs_in, expr_out = exprs[: len(op[0])], exprs[len(op[0])] + if update: + # Check that all axes in first input expression also appear in output expression + axes_in = { + axis.name + for axis in exprs_in[0].all() + if isinstance(axis, einx.expr.stage3.Axis) + } + axes_out = { + axis.name + for axis in expr_out.all() + if isinstance(axis, einx.expr.stage3.Axis) + } + if not axes_in.issubset(axes_out): + raise ValueError( + f"Output expression does not contain all axes from first input expression: " + f"{axes_in - axes_out}" + ) + return exprs_in, expr_out