Initial enablement for text-embedding #758

Open · wants to merge 25 commits into habana_main

Commits (25)
56b42c3  Initial draft to enable embedding task. (libinta, Jan 24, 2025)
b62d611  remove ENCODER_ONLY (libinta, Jan 24, 2025)
a647baa  Added support for embedding model with self attention without causal … (libinta, Jan 27, 2025)
46e1aad  Change set_attn_bias padding element from -math.inf to -3e38 as -math… (libinta, Jan 27, 2025)
2f74e6b  rewrite is_causal and add dbg msg (libinta, Jan 27, 2025)
99947c8  update maskoff value (libinta, Jan 27, 2025)
094294c  fix wrong base mask (libinta, Jan 28, 2025)
1c7416f  cleanup code (libinta, Jan 29, 2025)
c6cdae1  cleanup code (libinta, Jan 29, 2025)
8ac281b  cleanup code (libinta, Jan 29, 2025)
e72c2f0  Add pooler support for padded batch inputs for hpu with CLSPoll, Last… (libinta, Jan 30, 2025)
7c1c74b  add meanpool for padded input (libinta, Jan 30, 2025)
5c49ca1  revert bert change (libinta, Jan 30, 2025)
ae6fbe0  modify meanpool for padded input (libinta, Jan 30, 2025)
d65340a  write is_pooler function (libinta, Jan 30, 2025)
0c28519  fix is_causal logic (libinta, Jan 31, 2025)
1fe398f  Set is_causal based on attn_type (libinta, Feb 1, 2025)
c3a92f3  Set is_causal based on attn_type (libinta, Feb 1, 2025)
55ae676  fix with warmup issue (libinta, Feb 4, 2025)
787700b  fix cpu test issue and format (libinta, Feb 5, 2025)
6f02b86  fix code format (libinta, Feb 5, 2025)
b97f7c6  Merge branch 'habana_main' into dev/enable_embedding_ace (libinta, Feb 5, 2025)
593ded0  fix hpu attn coding issue (libinta, Feb 5, 2025)
1185c2e  add support for batch padding (libinta, Feb 5, 2025)
53f94e0  Merge branch 'habana_main' into dev/enable_embedding_ace (kzawora-intel, Feb 12, 2025)
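Taken together, the commits add encoder-only (non-causal) self-attention to the HPU attention backend and teach the poolers to handle padded batch inputs, which is what a text-embedding workload needs end to end. A minimal usage sketch, with the entry-point names treated as assumptions since they vary across vLLM releases:

from vllm import LLM

# Hedged sketch, not part of this PR: serving a text-embedding model once the
# HPU backend supports encoder-only attention and padded-batch pooling.
# "BAAI/bge-base-en-v1.5" is only an example checkpoint, and the task name and
# output fields ("embed", .outputs.embedding) may differ between vLLM releases.
llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")

outputs = llm.embed(["The quick brown fox", "jumps over the lazy dog"])
for output in outputs:
    print(len(output.outputs.embedding))  # one embedding vector per prompt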
requirements-hpu.txt (2 changes: 1 addition & 1 deletion)

@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@bac2a62
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@710642
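The bumped vllm-hpu-extension revision is picked up when the HPU requirements are installed, for example:

pip install -r requirements-hpu.txt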
vllm/attention/backends/hpu_attn.py (12 changes: 8 additions & 4 deletions)

@@ -158,7 +158,8 @@ def __init__(

         self.attn_type = attn_type
         if (self.attn_type != AttentionType.DECODER
-                and self.attn_type != AttentionType.ENCODER_DECODER):
+                and self.attn_type != AttentionType.ENCODER_DECODER
+                and self.attn_type != AttentionType.ENCODER_ONLY):
             raise NotImplementedError("Encoder self-attention "
                                       "is not implemented for "
                                       "HPUAttentionImpl")
@@ -204,7 +205,8 @@ def forward(
         value = value.view(-1, self.num_kv_heads, self.head_size)
         block_indices = attn_metadata.block_indices
         block_offsets = attn_metadata.block_offsets
-        if attn_metadata.is_prompt:
+        if attn_metadata.is_prompt and self.attn_type \
+                is not AttentionType.ENCODER_ONLY:
             key = key.unflatten(0, (block_indices.size(0), -1))
             value = value.unflatten(0, (block_indices.size(0), -1))
         if kv_cache is not None and isinstance(kv_cache, tuple):
@@ -224,6 +226,7 @@ def forward(
             query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
             kv_shape = (batch_size, seq_len_kv, self.num_kv_heads,
                         self.head_size)
+
             if attn_metadata is None or attn_metadata.block_list is None:
                 if not self.prefill_use_fusedsdpa:
                     # TODO: move this outside of model
@@ -238,7 +241,7 @@ def forward(
                                                 (1, self.num_kv_heads, 1, 1))
                         attn_bias.add_(position_bias)
                 else:
-                    attn_bias = None
+                    attn_bias = attn_metadata.attn_bias

                 out = ops.prompt_attention(
                     query.view(query_shape),
@@ -251,7 +254,8 @@ def forward(
                     softmax_op=self.softmax,
                     matmul_av_op=self.matmul_av,
                     valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                    fsdpa_op=self.fused_scaled_dot_product_attention,
+                    fsdpa_op=self.fused_scaled_dot_product_attention
+                    if self.prefill_use_fusedsdpa else None,
                 )
             else:
                 # TODO: enable FusedSDPA
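For ENCODER_ONLY attention the prompt path skips the KV-cache unflattening and relies on the attn_bias supplied through attn_metadata, which is bidirectional: no causal triangle, only padding positions masked off. A rough sketch of building such a bias (not the PR's code; the helper name and shapes are assumptions), using a large finite negative value in the spirit of the -3e38 mask-off value from the commit history, since -inf can turn fully masked rows into NaNs after softmax:

import torch

# Rough sketch (not the PR's code): a bidirectional attention bias for a
# padded prompt. Only padding positions are masked, with a large finite
# negative value (~-3e38 for bfloat16/float32) rather than -inf.
def make_encoder_attn_bias(seq_lens, padded_len, dtype=torch.bfloat16):
    mask_off = torch.finfo(dtype).min
    bias = torch.zeros(len(seq_lens), 1, padded_len, padded_len, dtype=dtype)
    for i, n in enumerate(seq_lens):
        bias[i, :, :, n:] = mask_off  # mask padded keys
        bias[i, :, n:, :] = mask_off  # mask padded queries
    return bias

# Two prompts of lengths 3 and 5, both padded to 8 tokens.
attn_bias = make_encoder_attn_bias([3, 5], padded_len=8)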
vllm/model_executor/layers/pooler.py (52 changes: 35 additions & 17 deletions)

@@ -77,7 +77,9 @@ def get_prompt_lens(
         pooling_metadata: PoolingMetadata,
     ) -> torch.Tensor:
         return PoolingTensors.from_pooling_metadata(
-            pooling_metadata, hidden_states.device).prompt_lens
+            pooling_metadata, hidden_states.device
+        ).prompt_lens, PoolingTensors.from_pooling_metadata(
+            pooling_metadata, hidden_states.device).prompt_offsets

     def extract_states(
         self,
@@ -107,10 +109,14 @@ def extract_states(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-
-        first_token_flat_indices = torch.zeros_like(prompt_lens)
-        first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
+        prompt_lens, prompt_offsets = self.get_prompt_lens(
+            hidden_states, pooling_metadata)
+        if prompt_offsets is not None:
+            first_token_flat_indices = prompt_offsets
+        else:
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
+                                                         dim=0)[:-1]
         return hidden_states[first_token_flat_indices]
@@ -121,9 +127,15 @@ def extract_states(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-
-        last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
+        prompt_lens, prompt_offsets = self.get_prompt_lens(
+            hidden_states, pooling_metadata)
+        if prompt_offsets is not None:
+            last_token_flat_indices = (torch.sum(torch.cat(
+                (prompt_lens.unsqueeze(0), prompt_offsets.unsqueeze(0)), 0),
+                                                 dim=0,
+                                                 keepdim=True) - 1).squeeze(0)
+        else:
+            last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
         return hidden_states[last_token_flat_indices]
@@ -134,7 +146,8 @@ def extract_states(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
+        prompt_lens, prompt_offsets = self.get_prompt_lens(
+            hidden_states, pooling_metadata)

         offset = 0
         pooled_data = list[torch.Tensor]()
@@ -152,14 +165,18 @@ def extract_states(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
-
+        prompt_lens, prompt_offsets = self.get_prompt_lens(
+            hidden_states, pooling_metadata)
         cumsum = torch.cumsum(hidden_states, dim=0)
-        start_indices = torch.cat([
-            torch.tensor([0], device=hidden_states.device),
-            torch.cumsum(prompt_lens[:-1], dim=0)
-        ])
-        end_indices = torch.cumsum(prompt_lens, dim=0)
+        if prompt_offsets is not None:
+            end_indices = prompt_offsets + prompt_lens
+            start_indices = prompt_offsets
+        else:
+            start_indices = torch.cat([
+                torch.tensor([0], device=hidden_states.device),
+                torch.cumsum(prompt_lens[:-1], dim=0)
+            ])
+            end_indices = torch.cumsum(prompt_lens, dim=0)
         return (cumsum[end_indices - 1] - cumsum[start_indices] +
                 hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
@@ -184,7 +201,8 @@ def extract_states(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Union[list[torch.Tensor], torch.Tensor]:
-        prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata)
+        prompt_lens, prompt_offsets = self.get_prompt_lens(
+            hidden_states, pooling_metadata)

         returned_token_ids = self.returned_token_ids
         if returned_token_ids is not None and len(returned_token_ids) > 0:
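With batch padding, prompt i no longer starts at the cumulative sum of the preceding prompt lengths in the flattened hidden states, which is why every pooler now branches on prompt_offsets. A small sketch (not the PR's code) of the padded-case index math the hunks above implement:

import torch

# Flattened, padded batch: two prompts of real lengths 3 and 5, each padded
# to 8 tokens, so hidden_states has 16 rows and the second prompt starts at 8.
hidden_size, padded_len = 4, 8
prompt_lens = torch.tensor([3, 5])
prompt_offsets = torch.tensor([0, padded_len])
hidden_states = torch.randn(2 * padded_len, hidden_size)

cls_states = hidden_states[prompt_offsets]                     # CLSPool
last_states = hidden_states[prompt_offsets + prompt_lens - 1]  # LastPool

# MeanPool via a running cumsum, averaging only the real (unpadded) tokens.
cumsum = torch.cumsum(hidden_states, dim=0)
start, end = prompt_offsets, prompt_offsets + prompt_lens
mean_states = (cumsum[end - 1] - cumsum[start] +
               hidden_states[start]) / prompt_lens.unsqueeze(1)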
vllm/model_executor/pooling_metadata.py (23 changes: 19 additions & 4 deletions, mode 100644 → 100755)

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
@@ -17,30 +17,36 @@ class PoolingMetadata:
         seq_groups: List of (seq_ids, pooling_params).
         seq_data: A mapping of sequence ID to additional sequence data.
         prompt_lens: List of the lengths of each prompt.
+        prompt_offsets: List of prompt start offsets for each prompt
+            when flat out with padding
     """

     def __init__(
         self,
         seq_groups: List[Tuple[List[int], PoolingParams]],
         seq_data: Dict[int, Any],  # Specific data related to sequences
         prompt_lens: List[int],
+        prompt_offsets: Optional[List[int]] = None,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.prompt_offsets = prompt_offsets

     def __repr__(self) -> str:
         return ("PoolingMetadata("
                 f"seq_groups={self.seq_groups}, "
                 f"seq_data={self.seq_data}, "
-                f"prompt_lens={self.prompt_lens})")
+                f"prompt_lens={self.prompt_lens}, "
+                f"prompt_offsets={self.prompt_offsets})")


 @dataclass
 class PoolingTensors:
     """Tensors for pooling."""

     prompt_lens: torch.Tensor
+    prompt_offsets: torch.Tensor

     @classmethod
     def from_pooling_metadata(
@@ -64,6 +70,15 @@ def from_pooling_metadata(
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        return cls(prompt_lens=prompt_lens_t.to(device=device,
-                                                non_blocking=True), )
+
+        if pooling_metadata.prompt_offsets is not None:
+            prompt_offsets_t = torch.tensor(
+                pooling_metadata.prompt_offsets,
+                device="cpu",
+                dtype=torch.long,
+                pin_memory=pin_memory,
+            ).to(device=device, non_blocking=True)
+        else:
+            prompt_offsets_t = None
+        return cls(prompt_lens=prompt_lens_t.to(device=device,
+                                                non_blocking=True),
+                   prompt_offsets=prompt_offsets_t)
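A short sketch of how the extended metadata might be built and converted for a padded batch (assumes the classes as changed in this PR; the seq_groups and seq_data values are placeholders):

from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.pooling_params import PoolingParams

# Two prompts of lengths 3 and 5, each padded to 8 tokens, so the second
# prompt starts at flat index 8 instead of 3.
metadata = PoolingMetadata(
    seq_groups=[([0], PoolingParams()), ([1], PoolingParams())],
    seq_data={},  # placeholder; filled in by the model runner in practice
    prompt_lens=[3, 5],
    prompt_offsets=[0, 8],
)

tensors = PoolingTensors.from_pooling_metadata(metadata, device="cpu")
# tensors.prompt_lens    -> tensor([3, 5])
# tensors.prompt_offsets -> tensor([0, 8])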