Commit: refine code
shadow150519 committed Dec 3, 2024
1 parent 7481440 commit 7e73ee2
Showing 4 changed files with 13 additions and 19 deletions.
23 changes: 11 additions & 12 deletions torch_xla/csrc/flash_attention_utils.cpp
@@ -3,8 +3,6 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-#include <iostream>
-
 #include "absl/strings/numbers.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_join.h"
@@ -390,15 +388,15 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
     TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
     TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
     TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
-                batch_size == 1);
+                batch_size == 1);  // Packed qkv currently supports only batch_size == 1; this may need to change in the future.
     TORCH_CHECK(
         cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-        cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q + 1}),
-        "cu_seqlens_q shape should be batch_size+1 or seqlen_q");
+        cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q*batch_size + 1}),
+        "cu_seqlens_q shape should be batch_size+1 or seqlen_q*batch_size+1");
     TORCH_CHECK(
         cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-        cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k + 1}),
-        "cu_seqlens_k shape should be batch_size+1 or seqlen_k");
+        cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k*batch_size + 1}),
+        "cu_seqlens_k shape should be batch_size+1 or seqlen_k*batch_size+1");
   }
 
   int alibi_slopes_batch_stride = 0;
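For context on the relaxed shape check above: cu_seqlens_q/cu_seqlens_k hold cumulative sequence lengths for varlen attention, and the check now also accepts a buffer of seqlen*batch_size + 1 entries for packed input. A minimal sketch of that convention; the concrete lengths below are illustrative assumptions, not values from this commit:

// Illustrative only: two packed sequences of lengths 3 and 5 in a single
// batch_size == 1 row with seqlen_q == 8. The valid prefix holds the
// cumulative lengths {0, 3, 8}; the rest of the seqlen_q*batch_size + 1
// buffer appears to be padded with -1 (cf. cu_seqlen.fill_(-1) added below).
torch::Tensor cu_seqlens_q =
    torch::tensor({0, 3, 8, -1, -1, -1, -1, -1, -1}, torch::kInt32);
// cu_seqlens_q.numel() == seqlen_q * batch_size + 1 == 9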
@@ -451,11 +449,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
 
   auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
   torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32))
-                           .unsqueeze(1);  // (batch_size,1)
+                           .unsqueeze(1);
   torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32))
-                           .unsqueeze(0);  // (1,seqlen)
+                           .unsqueeze(0);
   torch::Tensor mask =
-      cols < nonzero_counts.unsqueeze(1);  // (1,seqlen) < (batch_size, 1)
+      cols < nonzero_counts.unsqueeze(1);
   max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
 
   torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
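As a side note on the broadcast in cu_seqlens_to_indices above, a small sketch of how max_seqlen_in_batch falls out of the mask; the counts below are assumptions for illustration:

// Illustrative only: batch_size == 2, seqlen == 6, per-sequence token counts 3 and 5.
torch::Tensor nonzero_counts = torch::tensor({3, 5}, torch::kInt32);
torch::Tensor cols = torch::arange(6, torch::kInt32).unsqueeze(0);  // shape (1, 6)
torch::Tensor mask = cols < nonzero_counts.unsqueeze(1);            // broadcasts to (2, 6)
// mask rows: {1,1,1,0,0,0} and {1,1,1,1,1,0}
int max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();  // == 5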
@@ -496,7 +494,7 @@ torch::Tensor unpad_softmax_lse(
   at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
   result.copy_(pad_softmax_lse.transpose(1, 2)
                    .reshape({batch_size * max_seqlen, nhead})
-                   .index({indices, torch::indexing::Slice()}));
+                   .index({indices, torch::indexing::Slice()}));  // Needs updating if packed tensors with batch_size > 1 are supported in the future.
   return result.transpose(0, 1).unsqueeze(0);
 }

@@ -514,7 +512,7 @@ torch::Tensor pad_softmax_lse(
               "indice should be same size with softmax_lse")
 
   at::Tensor result =
-      at::empty({batch_size * max_seq_len, nheads}, softmax_lse.options());
+      at::zeros({batch_size * max_seq_len, nheads}, softmax_lse.options());
 
   result.index_put_({indices, torch::indexing::Slice()},
                     softmax_lse.squeeze(0).transpose(0, 1));
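On the at::empty -> at::zeros change above: rows that index_put_ never touches (padding positions) now stay zero instead of containing uninitialized memory. A toy sketch of that effect; shapes and values are assumptions, not from the commit:

// Illustrative only: 4 padded rows, 2 heads, but only rows {0, 2} carry real tokens.
torch::Tensor idx = torch::tensor({0, 2}, torch::kLong);
torch::Tensor out = at::zeros({4, 2});  // padding rows are well-defined zeros
out.index_put_({idx, torch::indexing::Slice()}, torch::ones({2, 2}));
// Rows 1 and 3 of `out` remain all-zero; with at::empty they would hold garbage.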
@@ -527,6 +525,7 @@ at::Tensor position_ids_to_indices(const at::Tensor& position_ids,
                                    int& max_seqlen_in_batch, int& total,
                                    at::Tensor& cu_seqlen,
                                    int& real_batch_size) {
+  cu_seqlen.fill_(-1);
   at::Tensor flatten_position_ids = position_ids.flatten();
   at::Tensor indices =
       torch::arange(flatten_position_ids.size(0),
2 changes: 0 additions & 2 deletions torch_xla/csrc/ops/flash_attention_varlen_forward.cpp
@@ -4,8 +4,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
-#include <iostream>
-
 #include "cutlass/numeric_types.h"
 #include "flash.h"
 #include "static_switch.h"
@@ -87,9 +87,9 @@ void custom_call_flash_attention_varlen_position_ids_backward(
   at::Tensor softmax_lse =
       torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
                        opts.dtype(torch::kFloat));
-  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.seqlen_q + 1},
+  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b*params.seqlen_q + 1},
                                              opts.dtype(torch::kInt32));
-  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.seqlen_k + 1},
+  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b*params.seqlen_k + 1},
                                              opts.dtype(torch::kInt32));
 
   // Outputs
@@ -4,8 +4,6 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/extension.h>
 
-#include <iostream>
-
 #include "cutlass/numeric_types.h"
 #include "flash.h"
 #include "static_switch.h"
@@ -96,7 +94,6 @@ void custom_call_flash_attention_varlen_position_ids_forward(
       torch::from_blob(buffers[6 + buf_offset], {2}, opts.dtype(torch::kInt64));
   softmax_lse.fill_(0);
   o_output.fill_(0);
-  cu_seqlens_k.fill_(-1);
 
   int max_seqlen_in_batch_k = params.seqlen_k;
   int total_k = params.b * params.seqlen_k;
