From 7e73ee211eaf75788e3b83244d760dd635cc7f09 Mon Sep 17 00:00:00 2001
From: tianxingwang
Date: Tue, 3 Dec 2024 14:46:28 +0800
Subject: [PATCH] refine code

---
 torch_xla/csrc/flash_attention_utils.cpp      | 23 +++++++++----------
 .../ops/flash_attention_varlen_forward.cpp    |  2 --
 ...attention_varlen_position_ids_backward.cpp |  4 ++--
 ..._attention_varlen_position_ids_forward.cpp |  3 ---
 4 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/torch_xla/csrc/flash_attention_utils.cpp b/torch_xla/csrc/flash_attention_utils.cpp
index 950eaf50ce39..376f492a1bb9 100644
--- a/torch_xla/csrc/flash_attention_utils.cpp
+++ b/torch_xla/csrc/flash_attention_utils.cpp
@@ -3,8 +3,6 @@
 #include
 #include
 
-#include
-
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -390,15 +388,15 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
     TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
     TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
     TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
-                batch_size == 1);
+                batch_size == 1);  // Packed QKV currently only supports batch_size == 1; this may change in the future.
     TORCH_CHECK(
         cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-            cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}),
-        "cu_seqlens_q shape should be batch_size+1 or seqlen_q");
+            cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q * batch_size + 1}),
+        "cu_seqlens_q shape should be batch_size+1 or seqlen_q*batch_size+1");
     TORCH_CHECK(
         cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-            cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}),
-        "cu_seqlens_k shape should be batch_size+1 or seqlen_k");
+            cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k * batch_size + 1}),
+        "cu_seqlens_k shape should be batch_size+1 or seqlen_k*batch_size+1");
   }
 
   int alibi_slopes_batch_stride = 0;
@@ -451,11 +449,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
 
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
   torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32))
-                           .unsqueeze(1);  // (batch_size,1)
+                           .unsqueeze(1);
   torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32))
-                           .unsqueeze(0);  // (1,seqlen)
+                           .unsqueeze(0);
   torch::Tensor mask =
-      cols < nonzero_counts.unsqueeze(1);  // (1,seqlen) < (batch_size, 1)
+      cols < nonzero_counts.unsqueeze(1);
   max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
 
   torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
@@ -496,7 +494,7 @@ torch::Tensor unpad_softmax_lse(
   at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
   result.copy_(pad_softmax_lse.transpose(1, 2)
                    .reshape({batch_size * max_seqlen, nhead})
-                   .index({indices, torch::indexing::Slice()}));
+                   .index({indices, torch::indexing::Slice()}));  // If packed tensors with batch size > 1 are supported in the future, this needs to be updated.
 
   return result.transpose(0, 1).unsqueeze(0);
 }
@@ -514,7 +512,7 @@ torch::Tensor pad_softmax_lse(
               "indice should be same size with softmax_lse")
 
   at::Tensor result =
-      at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options());
+      at::zeros({batch_size * max_seq_len, nheads}, softmax_lse.options());
   result.index_put_({indices, torch::indexing::Slice()},
                     softmax_lse.squeeze(0).transpose(0, 1));
 
@@ -527,6 +525,7 @@
 at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
                                    int& max_seqlen_in_batch, int& total,
                                    at::Tensor& cu_seqlen, int& real_batch_size) {
+  cu_seqlen.fill_(-1);
   at::Tensor flatten_position_ids = position_ids.flatten();
   at::Tensor indices =
       torch::arange(flatten_position_ids.size(0),
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
index 92b06f1addf5..832e6c280faf 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -4,8 +4,6 @@
 #include
 #include
 
-#include
-
 #include "cutlass/numeric_types.h"
 #include "flash.h"
 #include "static_switch.h"
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
index 5cb54805a30f..132742921833 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_backward.cpp
@@ -87,9 +87,9 @@ void custom_call_flash_attention_varlen_position_ids_backward(
   at::Tensor softmax_lse =
       torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
                        opts.dtype(torch::kFloat));
-  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.seqlen_q + 1},
+  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b * params.seqlen_q + 1},
                                              opts.dtype(torch::kInt32));
-  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.seqlen_k + 1},
+  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b * params.seqlen_k + 1},
                                              opts.dtype(torch::kInt32));
 
   // Outputs
diff --git a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
index 945a807236cc..23d5d1c4cc6f 100644
--- a/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
+++ b/torch_xla/csrc/ops/flash_attention_varlen_position_ids_forward.cpp
@@ -4,8 +4,6 @@
 #include
 #include
 
-#include
-
 #include "cutlass/numeric_types.h"
 #include "flash.h"
 #include "static_switch.h"
@@ -96,7 +94,6 @@ void custom_call_flash_attention_varlen_position_ids_forward(
       torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64));
   softmax_lse.fill_(0);
   o_output.fill_(0);
-  cu_seqlens_k.fill_(-1);
 
   int max_seqlen_in_batch_k = params.seqlen_k;
   int total_k = params.b * params.seqlen_k;
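
Note: the standalone sketch below is illustrative and not part of the patch. It shows, on CPU with the public libtorch API, the mask-based derivation that cu_seqlens_to_indices performs on GPU: cumulative sequence lengths become a (batch_size, seqlen) validity mask, which yields max_seqlen_in_batch and the flat indices later used to unpad/pad softmax_lse. The file name, sample values, and variable names are assumptions chosen for the example, not code from this repository.

// cu_seqlens_indices_sketch.cpp -- illustrative only; build against libtorch.
#include <torch/torch.h>

#include <iostream>

int main() {
  // cu_seqlens = [0, 3, 5]: two packed sequences of length 3 and 2,
  // padded out to seqlen = 4 in the dense (batch_size, seqlen) layout.
  torch::Tensor cu_seqlens = torch::tensor({0, 3, 5}, torch::kInt32);
  int64_t batch_size = cu_seqlens.size(0) - 1;
  int64_t seqlen = 4;

  // Per-sequence lengths from consecutive differences of cu_seqlens.
  torch::Tensor seqlens =
      cu_seqlens.slice(0, 1) - cu_seqlens.slice(0, 0, batch_size);

  // Comparing (1, seqlen) column ids against (batch_size, 1) lengths
  // broadcasts to a (batch_size, seqlen) mask over the non-padded tokens.
  torch::Tensor cols = torch::arange(seqlen, torch::kInt32).unsqueeze(0);
  torch::Tensor mask = cols < seqlens.unsqueeze(1);

  int max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();

  // Flat indices of the valid tokens in the padded (batch_size * seqlen)
  // layout; these are the rows gathered or scattered when (un)padding
  // softmax_lse. Here: [0, 1, 2, 4, 5], with max_seqlen_in_batch = 3.
  torch::Tensor indices = torch::nonzero(mask.flatten()).flatten();

  std::cout << "max_seqlen_in_batch = " << max_seqlen_in_batch << "\n"
            << "indices = " << indices << std::endl;
  return 0;
}

Because index_put_ writes only the rows named by these indices, rows that correspond to padding are never touched; this is presumably why the patch switches pad_softmax_lse from at::empty to at::zeros, so the untouched padding rows hold zeros rather than uninitialized memory.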