Skip to content

Commit

Permalink
[torchlib] Fix prod (#2038)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Jan 24, 2025
1 parent dbf2353 commit 84dfcad
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
16 changes: 13 additions & 3 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6682,11 +6682,21 @@ def aten_prelu_backward(
raise NotImplementedError()


@torch_op("aten::prod.dim_int", trace_only=True)
def aten_prod(self: TReal, dim: int, keepdim: bool = False) -> TReal:
@torch_op("aten::prod", trace_only=True)
def aten_prod(self: TReal, dtype: int = -1) -> TReal:
    """prod(Tensor self, *, ScalarType? dtype=None) -> Tensor

    Reduces all elements of the input tensor by multiplication, producing
    a scalar tensor. ``dtype`` uses the sentinel -1 (or ``None``) to mean
    "keep the input dtype"; any other value is treated as an ONNX dtype
    code to cast the input to before reducing.
    """
    if dtype != -1 and dtype is not None:
        # Cast first so the product is accumulated in the requested dtype.
        self = op.Cast(self, to=dtype)
    # No axes given: ReduceProd reduces over all dimensions.
    return op.ReduceProd(self)


@torch_op("aten::prod.dim_int", trace_only=True)
def aten_prod_dim_int(self: TReal, dim: int, keepdim: bool = False, dtype: int = -1) -> TReal:
    """prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

    # -1 (or None) is the sentinel for "no dtype requested"; otherwise cast
    # the input before reducing along the requested dimension.
    cast_requested = dtype is not None and dtype != -1
    if cast_requested:
        self = op.Cast(self, to=dtype)
    return op.ReduceProd(self, axes=[dim], keepdims=keepdim)


Expand Down
14 changes: 14 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,19 @@ def _where_input_wrangler(
),
TorchLibOpInfo("polar", core_ops.aten_polar),
TorchLibOpInfo("pow", core_ops.aten_pow),
TorchLibOpInfo("prod", core_ops.aten_prod).skip(
matcher=lambda sample: sample.kwargs.get("dim") is not None
or sample.kwargs.get("keepdim") is not None
or sample.kwargs.get("dtype") != -1,
reason="this Aten overload only accept 1 inputs: self",
),
TorchLibOpInfo("prod_dim_int", core_ops.aten_prod_dim_int).skip(
matcher=lambda sample: (
sample.kwargs.get("dim") is None and sample.kwargs.get("keepdim") is None
)
or sample.kwargs.get("dtype") != -1,
reason="this Aten overload can accept 3 inputs:(self, dim, keepdim)",
),
TorchLibOpInfo("nn.functional.prelu", core_ops.aten_prelu),
TorchLibOpInfo("ops.aten.rand", core_ops.aten_rand, nondeterministic=True),
TorchLibOpInfo("ops.aten.rand_like", core_ops.aten_rand_like, nondeterministic=True),
Expand Down Expand Up @@ -2203,6 +2216,7 @@ def _where_input_wrangler(
OPS_DB, "ops.aten._log_softmax", ("ops.aten._log_softmax_half",)
)
ops_test_common.duplicate_opinfo(OPS_DB, "ops.aten._softmax", ("ops.aten._softmax_half",))
ops_test_common.duplicate_opinfo(OPS_DB, "prod", ("prod_dim_int",))
ops_test_common.duplicate_opinfo(OPS_DB, "round", ("round_decimals",))
ops_test_common.duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "view_as_complex", ("view_as_complex_copy",))
Expand Down

0 comments on commit 84dfcad

Please sign in to comment.