diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7eb57a02f926c..d587d706362dd 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -72,7 +72,10 @@ def generate_random_inputs(
 
 
 def check_correctness(
-    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None, opset: int = None
+    model: ModelProto,
+    inputs: Optional[Dict[str, np.ndarray]] = None,
+    opset: int = None,
+    atol: float = 1e-5,
 ) -> None:
     """Run an onnx model in both onnxruntime and TVM through our importer
     confirm that the results match. Otherwise, an exception will be raised.
@@ -85,6 +88,9 @@ def check_correctness(
         An optional dictionary containing values for each input in the onnx model.
     opset: int
         The opset version to use for the onnx importer.
+    atol: float
+        The absolute tolerance for correctness checking. Some ops may show more
+        arithmetic variance than others.
     """
     if opset is not None:
         model.opset_import[0].version = opset
@@ -143,7 +149,7 @@ def check_correctness(
         # TODO Allow configurable tolerance.
         # Sometimes None is used to indicate an unused output.
         if ort_out is not None:
-            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=1e-5)
+            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, atol=atol)
 
 
 @pytest.mark.parametrize(
@@ -933,7 +939,8 @@ def verify_reduce_func(func, data, axis, keepdims):
 
         model = helper.make_model(graph, producer_name="reduce_test")
         inputs_dict = {"x": data}
-        check_correctness(model, inputs_dict, opset=11)
+        # Reduction ops accumulate arithmetic errors, so we use a higher tolerance.
+        check_correctness(model, inputs_dict, opset=11, atol=1e-4)
 
     for keepdims in [True, False]:
         verify_reduce_func(
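
Usage sketch (illustrative, not part of the patch): with the configurable `atol` added above, any test in this file that exercises numerically noisy ops can request a looser tolerance the same way `verify_reduce_func` now does. The test name and shapes below are hypothetical; the sketch assumes `numpy`, `onnx.helper`, `onnx.TensorProto`, and the `check_correctness` helper defined in this file are in scope.

    import numpy as np
    from onnx import TensorProto, helper

    def test_reduce_sum_relaxed_tolerance():
        # Hypothetical example: build a small ReduceSum model whose output
        # accumulates floating point error across all input elements.
        data = np.random.uniform(size=(2, 4, 8)).astype("float32")
        node = helper.make_node("ReduceSum", inputs=["x"], outputs=["y"], keepdims=1)
        graph = helper.make_graph(
            [node],
            "reduce_sum_test",
            inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(data.shape))],
            outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 1, 1])],
        )
        model = helper.make_model(graph, producer_name="reduce_sum_test")
        # Loosen atol from the default 1e-5 to 1e-4, mirroring the change to
        # verify_reduce_func, since reductions accumulate arithmetic error.
        check_correctness(model, {"x": data}, opset=11, atol=1e-4)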