Skip to content

Commit

Permalink
Zero dim fix
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 27, 2025
1 parent a73f1b3 commit 9d07625
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
20 changes: 19 additions & 1 deletion ndonnx/_opset_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,25 @@ def clip(input: _CoreArray, min: _CoreArray, max: _CoreArray) -> _CoreArray:

@eager_propagate
def matmul(a: _CoreArray, b: _CoreArray) -> _CoreArray:
return _CoreArray(op.matmul(a.var, b.var))
# TODO(adityagoel4512): this requires an upstream patch in onnxruntime
# onnxruntime goes into UB with zero size inputs
(out,) = op.if_(
op.equal(op.size(a.var), op.const(0, dtype=np.int64)),
then_branch=lambda: [
op.const(
np.zeros(
(),
dtype=np.result_type(
a.var.unwrap_tensor().dtype, b.var.unwrap_tensor().dtype
),
)
),
],
else_branch=lambda: [
op.matmul(a.var, b.var),
],
)
return _CoreArray(out)


@eager_propagate
Expand Down
23 changes: 23 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,3 +1067,26 @@ def test_repeat_raises(a, repeats, axis):
def test_getitem_bool_raises(x, index):
with pytest.raises(IndexError):
x[index]


@pytest.mark.parametrize(
"x, y",
[
(
ndx.asarray([], dtype=ndx.uint8),
ndx.asarray([], dtype=ndx.uint8),
),
(
ndx.asarray([], dtype=ndx.float32),
ndx.asarray([], dtype=ndx.int16),
),
(
ndx.asarray([1, 2, 3], dtype=ndx.uint8),
ndx.asarray([1, 2, 3], dtype=ndx.float32),
),
],
)
def test_matmul_zero_dims(x, y):
ndx_result = x @ y
np_result = x.to_numpy() @ y.to_numpy()
assert_array_equal(ndx_result.to_numpy(), np_result)

0 comments on commit 9d07625

Please sign in to comment.