diff --git a/einx/backend/_dask.py b/einx/backend/_dask.py
index adc871f..e8eae2b 100644
--- a/einx/backend/_dask.py
+++ b/einx/backend/_dask.py
@@ -33,7 +33,50 @@ def reshape(tensor, shape):
 
     transpose = op.transpose(tda.transpose)
     broadcast_to = op.broadcast_to(tda.broadcast_to)
-    einsum = op.einsum(tda.einsum)
+
+    @classmethod
+    @einx.trace
+    def einsum(backend, equation, *operands):
+        exprs = equation.split("->")
+        if len(exprs) != 2:
+            raise ValueError("Invalid einsum equation")
+        in_exprs = exprs[0].split(",")
+        out_expr = exprs[1]
+
+        # Remove scalars
+        scalars = []
+        for in_expr, operand in zip(in_exprs, operands):
+            if (len(in_expr) == 0) != (operand.shape == ()):
+                raise ValueError(
+                    f"Tensor and einsum expression do not match: {in_expr} and {operand.shape}"
+                )
+            if operand.shape == ():
+                scalars.append(operand)
+        operands = [operand for operand in operands if operand.shape != ()]
+        in_exprs = [in_expr for in_expr in in_exprs if len(in_expr) > 0]
+        assert len(in_exprs) == len(operands)
+        equation = ",".join(in_exprs) + "->" + out_expr
+
+        # Call without scalars
+        if len(operands) == 1:
+            if in_exprs[0] != out_expr:
+                output = op.einsum(tda.einsum)(equation, *operands)
+            else:
+                output = operands[0]
+        elif len(operands) > 1:
+            output = op.einsum(tda.einsum)(equation, *operands)
+        else:
+            output = None
+
+        # Multiply scalars
+        if len(scalars) > 0:
+            if output is None:
+                output = backend.multiply(*scalars)
+            else:
+                output = backend.multiply(output, *scalars)
+
+        return output
+
     arange = op.arange(tda.arange)
 
     @staticmethod
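
The new `einsum` strips 0-d ("scalar") operands, whose subscript in the equation is empty, before delegating to dask's einsum, and folds them back into the result with a plain multiply. Below is a minimal sketch of the resulting behavior, assuming `dask.array` and NumPy are available; the variable names are illustrative, not part of the einx API:

```python
import numpy as np
import dask.array as da

x = da.from_array(np.arange(6.0).reshape(2, 3), chunks=(2, 3))
y = da.from_array(np.ones((3, 4)), chunks=(3, 4))
s = da.from_array(np.array(2.0))  # 0-d tensor: empty subscript "" in an equation

# For "ab,bc,->ac" the patched backend removes the scalar operand, runs the
# remaining contraction "ab,bc->ac", and multiplies the scalar back in:
out = da.einsum("ab,bc->ac", x, y) * s

# For "ab,->ab" only the scalar remains to be stripped; the reduced equation
# "ab->ab" is an identity, so the einsum call is skipped entirely:
out_identity = x * s

print(out.compute())
print(out_identity.compute())
```

Two edge cases fall out of this structure: a single remaining operand whose input expression already equals the output expression is returned as-is (no identity einsum node in the dask graph), and an equation whose operands are all scalars reduces to `backend.multiply(*scalars)` with no einsum call at all.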