From 941b841052ccb97f56f99e9d2b4bc1fe3e4b1301 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 17 Jul 2024 16:28:30 -0700
Subject: [PATCH 1/4] Fix colfax_cutlass build

---
 userbenchmark/triton/cutlass_kernels/install.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/userbenchmark/triton/cutlass_kernels/install.py b/userbenchmark/triton/cutlass_kernels/install.py
index 471853a9d8..6ccdf787ac 100644
--- a/userbenchmark/triton/cutlass_kernels/install.py
+++ b/userbenchmark/triton/cutlass_kernels/install.py
@@ -6,7 +6,7 @@
 CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"]
 REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.parent
 FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
-FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass")
+FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass")
 COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
 COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass_kernels")
 

From 5334eb2b93ecadfe344ab7beb00a1cfe4a0a720d Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 31 Jul 2024 17:48:56 -0400
Subject: [PATCH 2/4] Fix the colfax_cutlass build

---
 .../cutlass_kernels/include/fmha_forward.h  | 24 ------------
 .../triton/cutlass_kernels/install.py       |  5 +--
 .../cutlass_kernels/src/fmha/register_op.cu | 37 +++++++++----------
 3 files changed, 19 insertions(+), 47 deletions(-)
 delete mode 100644 userbenchmark/triton/cutlass_kernels/include/fmha_forward.h

diff --git a/userbenchmark/triton/cutlass_kernels/include/fmha_forward.h b/userbenchmark/triton/cutlass_kernels/include/fmha_forward.h
deleted file mode 100644
index 35caedaf3a..0000000000
--- a/userbenchmark/triton/cutlass_kernels/include/fmha_forward.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-
-#pragma once
-
-template
-void fmhaForwardDevice(
-    int SEQLEN,
-    int KEYLEN,
-    int NUMHEADS,
-    int BATCH,
-    PrecType const* tensorQ,
-    PrecType const* tensorK,
-    OutputType const* tensorV,
-    OutputType* tensorS,
-    OutputType* tensorO,
-    AccumType* miOut,
-    AccumType* sPrimeOut,
-    int iterations,
-    float scale,
-    cudaStream_t stream = 0);
diff --git a/userbenchmark/triton/cutlass_kernels/install.py b/userbenchmark/triton/cutlass_kernels/install.py
index 6ccdf787ac..a4b9aaf39d 100644
--- a/userbenchmark/triton/cutlass_kernels/install.py
+++ b/userbenchmark/triton/cutlass_kernels/install.py
@@ -37,6 +37,7 @@
 COMPILER_FLAGS = [
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
+    f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}",
     f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
     f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
     f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
@@ -63,9 +64,7 @@
     "-ldl",
 ]
 FMHA_SOURCES = [
-    # Source 1
-    f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}",
-    # Source 2
+    # Source
     f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}",
     "-o",
     "fmha_forward_lib.so",
diff --git a/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu b/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
index 94b1c4c484..1a590acf4f 100644
--- a/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
+++ b/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
@@ -21,9 +21,9 @@
 // #include "autogen/cutlassF.h"
 #include "pytorch_utils.h"
 
-#include "fmha_forward.h"
+#include "fmha_forward.cu"
 
-template
+template
 std::tuple
 fmha_forward(
     const int64_t& seq_length,
@@ -31,7 +31,7 @@ fmha_forward(
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const float& scale) {
   TORCH_CHECK(query.dim() == 4);
   TORCH_CHECK(key.dim() == 4);
@@ -70,7 +70,7 @@ fmha_forward(
       query.options().dtype(CutlassToAtenDtype::atScalarType()));
   at::Tensor ret = at::empty(
       {B, M, num_heads, Kv},
-      query.options().dtype(CutlassToAtenDtype::atScalarType()));
+      query.options().dtype(CutlassToAtenDtype::atScalarType()));
 
   using AccumType = float; // AccumType is always float.
   at::Tensor devMiOut = at::empty(
@@ -80,16 +80,16 @@ fmha_forward(
       {B, M, num_heads},
       query.options().dtype(CutlassToAtenDtype::atScalarType()));
 
-  fmhaForwardDevice(
+  fmhaForwardDevice(
      seq_length,
      key_length,
      num_heads,
      B,
      reinterpret_cast(query.data_ptr()),
      reinterpret_cast(key.data_ptr()),
-     reinterpret_cast(value.data_ptr()),
-     reinterpret_cast(S.data_ptr()),
-     reinterpret_cast(ret.data_ptr()),
+     reinterpret_cast(value.data_ptr()),
+     reinterpret_cast(S.data_ptr()),
+     reinterpret_cast(ret.data_ptr()),
      reinterpret_cast(devMiOut.data_ptr()),
      reinterpret_cast(devSprimeOut.data_ptr()),
      1,
@@ -99,7 +99,7 @@ fmha_forward(
   return std::make_tuple(S, ret, devMiOut, devSprimeOut);
 }
 
-template
+template
 std::tuple
 launch_forward(
     const int64_t& seq_length,
@@ -107,17 +107,17 @@ launch_forward(
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const double& scale,
     const int64_t& Kdim) {
   if (Kdim == 64) {
-    return fmha_forward(
+    return fmha_forward(
        seq_length, key_length, batch, query, key, value, scale);
   } else if (Kdim == 128) {
-    return fmha_forward(
+    return fmha_forward(
        seq_length, key_length, batch, query, key, value, scale);
   } else if (Kdim == 256) {
-    return fmha_forward(
+    return fmha_forward(
        seq_length, key_length, batch, query, key, value, scale);
   }
   throw std::runtime_error("Kdim wrong");
@@ -131,18 +131,15 @@ fmha_forward_dispatch(
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const double& scale) {
   int64_t Kdim = query.size(-1);
 
   if (query.scalar_type() == at::kHalf){
-    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
+    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
   }
   else if (query.scalar_type() == at::kBFloat16){
-    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
-  }
-  else if (query.scalar_type() == at::kFloat8_e4m3fn){
-    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
+    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
   }
   else {
     std::cout << "unsupported data type: " << query.scalar_type() << std::endl;
@@ -159,7 +156,7 @@ fmha_forward_dispatch_meta(
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const double& scale) {
 
   TORCH_CHECK(query.dim() == 4);

From 34d80e497bda221c2e82eb9bbee73975f46a0ace Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 31 Jul 2024 17:56:50 -0400
Subject: [PATCH 3/4] Fix operator

---
 torchbenchmark/operators/flash_attention/operator.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index 3e08f79eca..6de38ddf6d 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -83,8 +83,8 @@
         # colfax Flash Attention V2 for Hopper
         torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
     else:
-        from userbenchmark.triton.utils import load_library
-        load_library("colfax_cutlass/fmha_forward_lib.so")
+        from userbenchmark.triton.loader import load_library
+        load_library("cutlass_kernels/fmha_forward_lib.so")
     colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
 except (ImportError, IOError, AttributeError):
     colfax_cutlass_fmha = None
@@ -128,6 +128,7 @@ class Operator(BenchmarkOperator):
     def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
         super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
+        self.use_cuda_graphs = False
         self.BATCH = args.batch
         self.H = args.n_heads
         self.D_HEAD = args.d_head

From b7c6ad7213625943d92e4fb4a7b23ff4e2615ab9 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Wed, 31 Jul 2024 21:55:58 -0400
Subject: [PATCH 4/4] Still use cudagraphs on colfax_cutlass

---
 torchbenchmark/operators/flash_attention/operator.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index 6de38ddf6d..53e77670a8 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -128,7 +128,6 @@ class Operator(BenchmarkOperator):
     def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
         super().__init__(tb_args, extra_args)
         args = parse_op_args(self.extra_args)
-        self.use_cuda_graphs = False
         self.BATCH = args.batch
         self.H = args.n_heads
         self.D_HEAD = args.d_head
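
Usage sketch (illustrative, not part of the patch series): after these patches, fmha_forward_lib.so is built from register_op.cu alone, which now #includes fmha_forward.cu directly instead of compiling it as a second source, and the benchmark loads the library through userbenchmark.triton.loader. The snippet below is a minimal sketch of how the registered op is expected to be called, based only on the signatures visible in the diffs above; the tensor sizes, the bfloat16 dtype, and the 1/sqrt(head_dim) softmax scale are assumptions, and it assumes fmha_forward_dispatch is what is registered as torch.ops.cutlass.fmha_forward (as operator.py implies).

    import math

    import torch
    from userbenchmark.triton.loader import load_library

    # Registers torch.ops.cutlass.fmha_forward from the library built by install.py
    # (same call as in operator.py; run from inside the tritonbench repo).
    load_library("cutlass_kernels/fmha_forward_lib.so")

    # Hypothetical sizes; head_dim must be 64, 128, or 256 per launch_forward, and the
    # dtype must be half or bfloat16 now that the fp8 branch has been removed.
    B, SEQLEN, H, D = 4, 2048, 16, 128
    q = torch.randn(B, SEQLEN, H, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, SEQLEN, H, D, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, SEQLEN, H, D, device="cuda", dtype=torch.bfloat16)

    scale = 1.0 / math.sqrt(D)  # assumed; the benchmark may pass a different scale
    # Arguments follow fmha_forward_dispatch(seq_length, key_length, batch, q, k, v, scale);
    # the outputs mirror std::make_tuple(S, ret, devMiOut, devSprimeOut) above.
    S, O, mi, s_prime = torch.ops.cutlass.fmha_forward(SEQLEN, SEQLEN, B, q, k, v, scale)

This assumes a Hopper-class GPU, which is what the colfax kernels target per the "colfax Flash Attention V2 for Hopper" comment in operator.py.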