From 5b3468fe427cb456a7296d17d192d18a239f712c Mon Sep 17 00:00:00 2001 From: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com> Date: Thu, 2 Jan 2025 15:18:22 +0800 Subject: [PATCH] [benchmark] skip perf test of cummin when triton < 3.0 (#385) --- benchmark/performance_utils.py | 25 +++++++++++++++++++++++++ benchmark/test_reduction_perf.py | 8 +++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/benchmark/performance_utils.py b/benchmark/performance_utils.py index 2c762f8d..b30243f2 100644 --- a/benchmark/performance_utils.py +++ b/benchmark/performance_utils.py @@ -1,4 +1,5 @@ import gc +import importlib import logging import time from typing import Any, Generator, List, Optional, Tuple @@ -30,6 +31,30 @@ torch_backend_device.matmul.allow_tf32 = False +def SkipVersion(module_name, skip_pattern): + cmp = skip_pattern[0] + assert cmp in ("=", "<", ">"), f"Invalid comparison operator: {cmp}" + try: + M, N = skip_pattern[1:].split(".") + M, N = int(M), int(N) + except Exception: + raise ValueError("Cannot parse version number from skip_pattern.") + + try: + module = importlib.import_module(module_name) + version = module.__version__ + major, minor = map(int, version.split(".")[:2]) + except Exception: + raise ImportError(f"Cannot determine version of module: {module_name}") + + if cmp == "=": + return major == M and minor == N + elif cmp == "<": + return (major, minor) < (M, N) + else: + return (major, minor) > (M, N) + + class Benchmark: device: str = device DEFAULT_METRICS = DEFAULT_METRICS diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index 9ac53ef3..44e59bec 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -11,6 +11,7 @@ Benchmark, Config, GenericBenchmark2DOnly, + SkipVersion, generate_tensor_input, unary_input_fn, ) @@ -153,7 +154,12 @@ def cumsum_input_fn(shape, cur_dtype, device): torch.cummin, cumsum_input_fn, FLOAT_DTYPES + INT_DTYPES, - marks=pytest.mark.cummin, + marks=[ + pytest.mark.cummin, + pytest.mark.skipif( + SkipVersion("triton", "<3.0"), reason="triton not supported" + ), + ], ), ], )