
Add tokenizer (#394)
SUMMARY:
* add endpoints to request `ModelConfig`, `SchedulerConfig`, `LoRAConfig`, and `ParallelConfig`
* factor out the tokenizer group creation logic into a utility function
* create the tokenizer_group on the client side (sketched below, after the changed-files summary)
robertgshaw2-redhat authored Jul 31, 2024
1 parent 98a7dab commit f5f0b45
Showing 8 changed files with 207 additions and 61 deletions.
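Taken together, the changes below stop the OpenAI API server from constructing a HuggingFace tokenizer itself and instead have the RPC client fetch the engine's configs over RPC and build the tokenizer group locally during an explicit `setup()` step. A minimal sketch of the resulting startup flow, assuming an RPC server is already listening on `port` (the `build_client` helper and the default port value are illustrative, not part of this diff):

```python
# Sketch only: RPCClient and setup() come from the diff below, but this helper
# and the default port are illustrative.
from vllm.entrypoints.openai.rpc.client import RPCClient


async def build_client(port: int = 5570) -> RPCClient:
    client = RPCClient(port)
    # setup() waits for the RPC server, fetches ModelConfig, DecodingConfig,
    # SchedulerConfig, ParallelConfig and LoRAConfig, then builds the
    # tokenizer group on the client side.
    await client.setup()
    return client
```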
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_completion.py
@@ -119,6 +119,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
choice = completion.choices[0]
assert len(choice.text) >= 5
assert choice.finish_reason == "length"
print(completion.usage)
assert completion.usage == openai.types.CompletionUsage(
completion_tokens=5,
prompt_tokens=6 + num_virtual_tokens,
27 changes: 26 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -7,7 +7,8 @@
from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
@@ -924,6 +925,14 @@ async def get_model_config(self) -> ModelConfig:
else:
return self.engine.get_model_config()

async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()

async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
@@ -932,6 +941,22 @@ async def get_decoding_config(self) -> DecodingConfig:
else:
return self.engine.get_decoding_config()

async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()

async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()

async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
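For reference, a minimal usage sketch of the new async getters, assuming `engine` is an already-constructed `AsyncLLMEngine` (the helper function itself is illustrative, not part of the diff):

```python
# Illustrative helper: query the configs that AsyncLLMEngine now exposes
# alongside get_model_config() and get_decoding_config().
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def log_engine_configs(engine: AsyncLLMEngine) -> None:
    parallel_config = await engine.get_parallel_config()
    scheduler_config = await engine.get_scheduler_config()
    # May be None when LoRA is not configured (see the LoRAConfig handling
    # in the RPC client further down).
    lora_config = await engine.get_lora_config()
    print(parallel_config, scheduler_config, lora_config)
```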
35 changes: 20 additions & 15 deletions vllm/engine/llm_engine.py
@@ -40,8 +40,8 @@
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
@@ -481,19 +481,12 @@ def get_tokenizer_for_seq(self,
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)

def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)

return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -755,10 +748,22 @@ def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config

def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config

def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config

def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config

def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups()
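The removed `_init_tokenizer` body above gives a good idea of what the factored-out utility does. Below is a sketch of what `init_tokenizer_from_configs` plausibly looks like, reconstructed from that removed code; the real implementation lives in `vllm/transformers_utils/tokenizer_group` and is not shown in this diff:

```python
# Reconstruction for illustration only; see vllm/transformers_utils/tokenizer_group
# for the actual utility introduced by this commit.
from vllm.transformers_utils.tokenizer_group import get_tokenizer_group


def init_tokenizer_from_configs(model_config, scheduler_config, parallel_config,
                                enable_lora: bool):
    # Same kwargs the old LLMEngine._init_tokenizer built by hand.
    return get_tokenizer_group(
        parallel_config.tokenizer_pool_config,
        tokenizer_id=model_config.tokenizer,
        enable_lora=enable_lora,
        max_num_seqs=scheduler_config.max_num_seqs,
        max_input_length=None,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision)
```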
8 changes: 2 additions & 6 deletions vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,6 @@
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
from transformers import AutoTokenizer

import vllm.envs as envs
from vllm.config import ModelConfig
@@ -115,11 +114,8 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]:
rpc_server_process.start()

## Then build the client for the backend process
# TODO: figure out a way around passing the tokenizer
backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(
args.model),
port=port)
await backend.wait_for_server()
backend = RPCClient(port)
await backend.setup()

try:
yield backend
8 changes: 6 additions & 2 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -29,8 +29,12 @@ class RPCAbortRequest:
class RPCUtilityRequest(Enum):
IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
DO_LOG_STATS = 3
CHECK_HEALTH = 4
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
CHECK_HEALTH = 8


RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
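The RPC server changes are not in the portion of the diff loaded here, but based on the new `AsyncLLMEngine` getters above, the server presumably maps each of these utility requests to the matching engine method. A purely hypothetical dispatch sketch:

```python
# HYPOTHETICAL: the actual server-side handler is not shown in this diff.
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.rpc import RPCUtilityRequest


async def get_config_for_request(engine: AsyncLLMEngine,
                                 request: RPCUtilityRequest):
    # One engine getter per config-style utility request.
    handlers = {
        RPCUtilityRequest.GET_MODEL_CONFIG: engine.get_model_config,
        RPCUtilityRequest.GET_DECODING_CONFIG: engine.get_decoding_config,
        RPCUtilityRequest.GET_PARALLEL_CONFIG: engine.get_parallel_config,
        RPCUtilityRequest.GET_SCHEDULER_CONFIG: engine.get_scheduler_config,
        RPCUtilityRequest.GET_LORA_CONFIG: engine.get_lora_config,
    }
    if request not in handlers:
        raise ValueError(f"Not a config request: {request}")
    return await handlers[request]()
```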
114 changes: 87 additions & 27 deletions vllm/entrypoints/openai/rpc/client.py
@@ -1,10 +1,11 @@
import pickle
from typing import AsyncIterator, Mapping, Optional
from typing import Any, AsyncIterator, Mapping, Optional

import zmq
import zmq.asyncio

from vllm.config import DecodingConfig, ModelConfig
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
@@ -14,24 +15,64 @@
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs


class RPCClient:

# TODO: check if opening all these sockets is an antipattern?
def __init__(self, tokenizer, port: int):
# ZMQ context.
def __init__(self, port: int):
self.context = zmq.asyncio.Context()

# TODO: do the tokenizer properly.
self.tokenizer = tokenizer
self.decoding_config = DecodingConfig()
self.path = f"tcp://localhost:{port}"

async def setup(self):
"""Setup the client before it starts sending server requests."""

# Wait until server is ready.
await self.wait_for_server()

# Get the configs.
self.model_config = await self._get_model_config_rpc()
self.decoding_config = await self._get_decoding_config_rpc()

# Create the tokenizer group.
# TODO: refactor OAI server to avoid needing this info.
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=(await self._get_scheduler_config_rpc()),
parallel_config=(await self._get_parallel_config_rpc()),
enable_lora=bool(await self._get_lora_config_rpc()),
)

def close(self):
"""Destroy the ZeroMQ Context."""
self.context.destroy()

async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
error_message: str) -> Any:
"""Send an RPC request that is expecting data back."""

# Connect to socket.
socket = self.context.socket(zmq.constants.DEALER)
socket.connect(self.path)

# Ping RPCServer with a request.
await socket.send(pickle.dumps(request))

# Await the data from the Server.
data = pickle.loads(await socket.recv())
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
else:
socket.close()
raise ValueError(error_message)

socket.close()

return data

async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
"""Send one-way RPC request to trigger an action."""
@@ -55,13 +96,14 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
return response

async def get_tokenizer(self, lora_request: LoRARequest):
# TODO: handle this via get data? - or avoid doing via RPC
return self.tokenizer
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

async def get_decoding_config(self):
# TODO: handle this via get data? - or avoid doing via RPC
return self.decoding_config

async def get_model_config(self):
return self.model_config

async def is_tracing_enabled(self):
# TODO: what is this?
return False
@@ -73,30 +115,48 @@ async def wait_for_server(self):
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")

async def get_model_config(self) -> ModelConfig:
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""

# Connect to socket.
socket = self.context.socket(zmq.constants.DEALER)
socket.connect(self.path)
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_MODEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server")

# Ping RPCServer with GET_MODEL_CONFIG request.
await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
async def _get_decoding_config_rpc(self) -> DecodingConfig:
"""Get DecodingConfig from the RPCServer"""

# Await the MODEL_CONFIG from the Server.
model_config = pickle.loads(await socket.recv())
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_DECODING_CONFIG,
expected_type=DecodingConfig,
error_message="Could not get DecodingConfig from RPC Server")

if not isinstance(model_config, ModelConfig):
socket.close()
raise ValueError("Expected ModelConfig object from RPC, but "
f"got {model_config}")
async def _get_parallel_config_rpc(self) -> ParallelConfig:
"""Get ParallelConfig from the RPCServer"""

socket.close()
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_PARALLEL_CONFIG,
expected_type=ParallelConfig,
error_message="Could not get ModelConfig from RPC Server")

async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
"""Get SchedulerConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")

async def _get_lora_config_rpc(self) -> Optional[LoRAConfig]:
"""Get LoRAConfig from the RPCServer"""

return model_config
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")

async def abort(self, request_id: str):
"""Send an RPCAbortRequest to the RPC Server"""
"""Send an ABORT_REQUEST signal to the RPC Server"""

await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
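Finally, a short usage sketch of the reworked client-side tokenizer path, assuming a client that has already completed `setup()` (the helper function and prompt are illustrative):

```python
# Illustrative only: after setup(), the tokenizer group lives in the client
# process, so tokenization no longer requires passing a tokenizer into RPCClient.
from vllm.entrypoints.openai.rpc.client import RPCClient


async def count_prompt_tokens(client: RPCClient, prompt: str) -> int:
    # Passing lora_request=None resolves to the base (non-LoRA) tokenizer.
    tokenizer = await client.get_tokenizer(lora_request=None)
    return len(tokenizer.encode(prompt))
```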

