diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a5b288..203fad8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: # to the main branch and tags. push: branches: - - "main" + - "*" tags: - "*" diff --git a/ndonnx/_opset_extensions.py b/ndonnx/_opset_extensions.py index df6e8cc..4675503 100644 --- a/ndonnx/_opset_extensions.py +++ b/ndonnx/_opset_extensions.py @@ -437,7 +437,25 @@ def clip(input: _CoreArray, min: _CoreArray, max: _CoreArray) -> _CoreArray: @eager_propagate def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray: - return _CoreArray(op.matmul(a.var, b.var)) + # TODO(adityagoel4512): this requires an upstream patch in onnxruntime + # onnxruntime goes into UB with zero size inputs + (out,) = op.if_( + op.equal(op.size(a.var), op.const(0, dtype=np.int64)), + then_branch=lambda: [ + op.const( + np.zeros( + (), + dtype=np.result_type( + a.var.unwrap_tensor().dtype, b.var.unwrap_tensor().dtype + ), + ) + ), + ], + else_branch=lambda: [ + op.matmul(a.var, b.var), + ], + ) + return _CoreArray(out) @eager_propagate diff --git a/xfails.txt b/xfails.txt index fa73a81..e778add 100644 --- a/xfails.txt +++ b/xfails.txt @@ -44,7 +44,6 @@ array_api_tests/test_has_names.py::test_has_names[linalg-diagonal] array_api_tests/test_has_names.py::test_has_names[linalg-eigh] array_api_tests/test_has_names.py::test_has_names[linalg-eigvalsh] array_api_tests/test_has_names.py::test_has_names[linalg-inv] -array_api_tests/test_has_names.py::test_has_names[linalg-matmul] array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm] array_api_tests/test_has_names.py::test_has_names[linalg-matrix_power] array_api_tests/test_has_names.py::test_has_names[linalg-matrix_rank]