From 441adf559f0ba295705a496d45b85b3be7da7464 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 8 May 2024 16:02:11 -0700 Subject: [PATCH] reduce autotuning range and add second bmm to benchmark --- methods/pca_topk/kernel/benchmark.py | 30 ++++++++++++++++++---------- methods/pca_topk/kernel/pca_topk.py | 16 ++++++++++++--- 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/methods/pca_topk/kernel/benchmark.py b/methods/pca_topk/kernel/benchmark.py index dd7edd5..75d35dd 100644 --- a/methods/pca_topk/kernel/benchmark.py +++ b/methods/pca_topk/kernel/benchmark.py @@ -4,18 +4,13 @@ from pca_topk import gather_outer_bmv_optimized, gather_inner_matrix_only_bmv_optimized from sparq import gather_outer_bmv, gather_inner_matrix_only_bmv - -B = 1 +B = 4 NH = 32 -S = 500 +S = 800 D = 128 dtype = torch.float16 -print("===== BENCHMARKING s.v with various sparsities =======") -print("Batch Size : ", B) -print("Number of Heads : ", NH) -print("Number of Key Tokens (or sequence length) : ", S) -print("Hidden dimension per head : ", D) + configs = [ triton.testing.Benchmark( @@ -25,7 +20,7 @@ # Possible values for `line_arg` # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. line_vals=["torch", "triton-optimized"], # Label name for the lines - line_names=["torch (full keys)", "Triton (Optimized)"], # Line styles + line_names=["torch (full keys and values)", "Triton (Optimized)"], # Line styles styles=[("black", "-"), ("blue", "-")], ylabel="TFLOPS", # Label name for the y-axis plot_name="matmul-performance-" + ("fp16 (time in ms)" ), # Name for the plot, used also as a file name for saving the plot. @@ -72,7 +67,22 @@ def benchmark_bmm2(sparsity, B, NH, S, D, provider): return ms, max_ms, min_ms + + +print("===== BENCHMARKING q@k.t() with various sparsities =======") +print("Batch Size : ", B) +print("Number of Heads : ", NH) +print("Number of Key Tokens (or sequence length) : ", S) +print("Hidden dimension per head : ", D) +result = benchmark_bmm1.run(print_data=True) + + + +print("===== BENCHMARKING s@v with various sparsities =======") +print("Batch Size : ", B) +print("Number of Heads : ", NH) +print("Number of Key Tokens (or sequence length) : ", S) +print("Hidden dimension per head : ", D) result = benchmark_bmm2.run(print_data=True) -print(result) diff --git a/methods/pca_topk/kernel/pca_topk.py b/methods/pca_topk/kernel/pca_topk.py index 5331af0..a3238fc 100644 --- a/methods/pca_topk/kernel/pca_topk.py +++ b/methods/pca_topk/kernel/pca_topk.py @@ -6,7 +6,7 @@ import triton.language as tl from torch import Tensor -def get_autotune_config(): +def get_autotune_config_outer(): return [ triton.Config({"n_chunk": 4}), triton.Config({"n_chunk": 8}), @@ -20,7 +20,7 @@ def get_autotune_config(): ] @triton.autotune( - configs=get_autotune_config(), + configs=get_autotune_config_outer(), key=['b', 'n', 'k'], ) @triton.jit @@ -110,8 +110,18 @@ def gather_outer_bmv_optimized(A: Tensor, B: Tensor, I: Tensor) -> Tensor: return Y +def get_autotune_config_inner(): + return [ + triton.Config({"n_chunk": 4}), + triton.Config({"n_chunk": 8}), + triton.Config({"n_chunk": 16}), + triton.Config({"n_chunk": 32}), + triton.Config({"n_chunk": 64}), + triton.Config({"n_chunk": 128}), + ] + @triton.autotune( - configs=get_autotune_config(), + configs=get_autotune_config_inner(), key=['b', 'n', 'k'], ) @triton.jit