Commit

refine code
shadow150519 committed Dec 3, 2024
1 parent 7e73ee2 commit e531a14
Showing 2 changed files with 23 additions and 16 deletions.
31 changes: 19 additions & 12 deletions torch_xla/csrc/flash_attention_utils.cpp
@@ -388,14 +388,17 @@ FlashAttentionBackwardParams get_flash_attention_backward_params(
    TORCH_CHECK(cu_seqlens_q.value().is_contiguous());
    TORCH_CHECK(cu_seqlens_k.value().is_contiguous());
    TORCH_CHECK(batch_size == cu_seqlens_q.value().numel() - 1 ||
-               batch_size == 1); // now pack qkv batch size only support 1, maybe need to change in the future
+               batch_size == 1);  // packed qkv currently only supports
+                                  // batch size 1; may change in the future
    TORCH_CHECK(
        cu_seqlens_q.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-           cu_seqlens_q.value().sizes() == torch::IntArrayRef({seqlen_q*batch_size + 1}),
+           cu_seqlens_q.value().sizes() ==
+               torch::IntArrayRef({seqlen_q * batch_size + 1}),
        "cu_seqlens_q shape should be batch_size+1 or seqlen_q*batch_size+1");
    TORCH_CHECK(
        cu_seqlens_k.value().sizes() == torch::IntArrayRef({batch_size + 1}) ||
-           cu_seqlens_k.value().sizes() == torch::IntArrayRef({seqlen_k*batch_size + 1}),
+           cu_seqlens_k.value().sizes() ==
+               torch::IntArrayRef({seqlen_k * batch_size + 1}),
        "cu_seqlens_k shape should be batch_size+1 or seqlen_k*batch_size+1");
  }
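
For context on the shapes these checks enforce: cu_seqlens is a prefix sum of per-sequence lengths with a leading zero, so a regular varlen batch carries batch_size + 1 entries, while the packed (position-ids) path, which currently requires batch_size == 1, carries seqlen_q * batch_size + 1 entries. A minimal sketch, not part of this commit (the helper name make_cu_seqlens is illustrative):

#include <torch/torch.h>
#include <vector>

// Illustrative helper: cu_seqlens as the prefix sum of sequence lengths.
// lengths {4, 2, 5} -> cu_seqlens {0, 4, 6, 11}  (batch_size + 1 entries)
torch::Tensor make_cu_seqlens(const std::vector<int64_t>& lengths) {
  torch::Tensor lens = torch::tensor(lengths, torch::kInt32);
  return torch::cat({torch::zeros({1}, torch::kInt32),
                     torch::cumsum(lens, 0).to(torch::kInt32)});
}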

@@ -448,12 +451,11 @@ at::Tensor cu_seqlens_to_indices(const at::Tensor& cu_seqlens, int batch_size,
  std::array<int64_t, 2> shape = {batch_size, seqlen};

  auto opts = torch::TensorOptions().dtype(scalar_type).device(torch::kCUDA);
-  torch::Tensor rows = torch::arange(shape[0], opts.dtype(torch::kInt32))
-                           .unsqueeze(1);
-  torch::Tensor cols = torch::arange(shape[1], opts.dtype(torch::kInt32))
-                           .unsqueeze(0);
-  torch::Tensor mask =
-      cols < nonzero_counts.unsqueeze(1);
+  torch::Tensor rows =
+      torch::arange(shape[0], opts.dtype(torch::kInt32)).unsqueeze(1);
+  torch::Tensor cols =
+      torch::arange(shape[1], opts.dtype(torch::kInt32)).unsqueeze(0);
+  torch::Tensor mask = cols < nonzero_counts.unsqueeze(1);
  max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();

  torch::Tensor matrix = torch::zeros(shape, opts.dtype(torch::kInt32));
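
A worked example of the mask construction above (illustrative and CPU-only; the commit's code runs on CUDA): with nonzero_counts = {3, 2} and seqlen = 4, broadcasting cols against the per-row counts yields the validity mask row by row.

#include <torch/torch.h>
#include <iostream>

int main() {
  // nonzero_counts = {3, 2}, seqlen = 4
  // cols = [[0, 1, 2, 3]]   (1 x 4)
  // mask = [[1, 1, 1, 0],   (2 x 4; row i is valid where cols < count_i)
  //         [1, 1, 0, 0]]
  torch::Tensor nonzero_counts = torch::tensor({3, 2}, torch::kInt32);
  torch::Tensor cols = torch::arange(4, torch::kInt32).unsqueeze(0);
  torch::Tensor mask = cols < nonzero_counts.unsqueeze(1);
  int max_seqlen_in_batch = torch::sum(mask, {1}).max().item<int>();
  std::cout << max_seqlen_in_batch << std::endl;  // prints 3
}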
Expand Down Expand Up @@ -492,9 +494,14 @@ torch::Tensor unpad_softmax_lse(
cu_seqlens_to_indices(valid_cu_seqlens, batch_size, max_seqlen,
torch::kInt64, max_seqlen_in_batch, total);
at::Tensor result = at::empty({total, nhead}, pad_softmax_lse.options());
result.copy_(pad_softmax_lse.transpose(1, 2)
.reshape({batch_size * max_seqlen, nhead})
.index({indices, torch::indexing::Slice()})); // if packed tensor's batch size > 1 is supported in the future, need to modify here in the future
result.copy_(
pad_softmax_lse.transpose(1, 2)
.reshape({batch_size * max_seqlen, nhead})
.index({indices,
torch::indexing::Slice()})); // if packed tensor's batch size
// > 1 is supported in the
// future, need to modify here
// in the future
return result.transpose(0, 1).unsqueeze(0);
}
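
Assuming the layout implied by the surrounding code ([batch, nhead, seqlen] for the padded LSE), the unpadding above reduces to this shape trace (a sketch, not part of the commit):

// pad_softmax_lse               : [batch_size, nhead, max_seqlen]
// .transpose(1, 2)              : [batch_size, max_seqlen, nhead]
// .reshape(...)                 : [batch_size * max_seqlen, nhead]
// .index({indices, Slice()})    : [total, nhead]   (valid rows only)
// .transpose(0, 1).unsqueeze(0) : [1, nhead, total]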

8 changes: 4 additions & 4 deletions
@@ -87,10 +87,10 @@ void custom_call_flash_attention_varlen_position_ids_backward(
  at::Tensor softmax_lse =
      torch::from_blob(buffers[5], {params.b, params.h, params.seqlen_q},
                       opts.dtype(torch::kFloat));
-  at::Tensor cu_seqlens_q = torch::from_blob(buffers[6], {params.b*params.seqlen_q + 1},
-                                             opts.dtype(torch::kInt32));
-  at::Tensor cu_seqlens_k = torch::from_blob(buffers[7], {params.b*params.seqlen_k + 1},
-                                             opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_q = torch::from_blob(
+      buffers[6], {params.b * params.seqlen_q + 1}, opts.dtype(torch::kInt32));
+  at::Tensor cu_seqlens_k = torch::from_blob(
+      buffers[7], {params.b * params.seqlen_k + 1}, opts.dtype(torch::kInt32));

  // Outputs
  at::Tensor dq =
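
A note on torch::from_blob, used for every buffer here: it aliases existing memory rather than copying, so the sizes passed in must match exactly what the XLA runtime allocated for buffers[6] and buffers[7]. A minimal self-contained sketch with illustrative values:

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  // from_blob wraps `raw` without copying; the caller must keep the
  // buffer alive for the tensor's entire lifetime.
  std::vector<int32_t> raw = {0, 4, 6, 11};
  at::Tensor cu_seqlens = torch::from_blob(
      raw.data(), {static_cast<int64_t>(raw.size())},
      torch::TensorOptions().dtype(torch::kInt32));
  std::cout << cu_seqlens << std::endl;  // 0 4 6 11 (CPU, int32)
}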
