This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Add tokenizer #394

Merged · 9 commits · Jul 31, 2024
Changes from 4 commits
27 changes: 26 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -7,7 +7,8 @@
from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
@@ -924,6 +925,14 @@ async def get_model_config(self) -> ModelConfig:
else:
return self.engine.get_model_config()

async def get_parallel_config(self) -> ParallelConfig:
Collaborator: can these new methods go into the VLLMBackend protocol as well?

Collaborator Author: I don't think they should be in the protocol, because the Protocol does not have to implement these, and most of the time it will not.

Collaborator: ah yeah, I see these are only on the AsyncLLMEngine 🌶️

"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
Collaborator: these ifs are outta control, the ray engine should totally be a separate VLLMBackend 😉
...a change for another day, or week

Collaborator Author: Good idea

return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()

async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
@@ -932,6 +941,22 @@ async def get_decoding_config(self) -> DecodingConfig:
else:
return self.engine.get_decoding_config()

async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()

async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()

async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
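For context on the review thread above about the VLLMBackend protocol: a minimal sketch, not part of this PR and explicitly not what the author chose, of what declaring these getters on a typing.Protocol might look like (the exact VLLMBackend definition is assumed):

from typing import Protocol

from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig)


class VLLMBackend(Protocol):
    """Sketch only; the real VLLMBackend protocol may differ."""

    async def get_model_config(self) -> ModelConfig:
        ...

    async def get_parallel_config(self) -> ParallelConfig:
        ...

    async def get_scheduler_config(self) -> SchedulerConfig:
        ...

    async def get_lora_config(self) -> LoRAConfig:
        ...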
35 changes: 20 additions & 15 deletions vllm/engine/llm_engine.py
@@ -40,8 +40,8 @@
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, _init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
@@ -481,19 +481,12 @@ def get_tokenizer_for_seq(self,
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)

def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)

return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return _init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -755,10 +748,22 @@ def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config

def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config

def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config

def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config

def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups()
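For reference, a hedged sketch of what the new _init_tokenizer_from_configs helper presumably does, reconstructed from the keyword-argument block that this refactor removes from _init_tokenizer; the actual helper in vllm/transformers_utils/tokenizer_group may differ:

from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                     get_tokenizer_group)


def _init_tokenizer_from_configs(model_config: ModelConfig,
                                 scheduler_config: SchedulerConfig,
                                 parallel_config: ParallelConfig,
                                 enable_lora: bool) -> BaseTokenizerGroup:
    # Same kwargs the engine used to assemble inline before this PR.
    return get_tokenizer_group(
        parallel_config.tokenizer_pool_config,
        tokenizer_id=model_config.tokenizer,
        enable_lora=enable_lora,
        max_num_seqs=scheduler_config.max_num_seqs,
        max_input_length=None,
        tokenizer_mode=model_config.tokenizer_mode,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.tokenizer_revision)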
8 changes: 2 additions & 6 deletions vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,6 @@
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount
from transformers import AutoTokenizer

import vllm.envs as envs
from vllm.config import ModelConfig
@@ -115,11 +114,8 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]:
rpc_server_process.start()

## Then build the client for the backend process
# TODO: figure out a way around passing the tokenizer
backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(
args.model),
port=port)
await backend.wait_for_server()
backend = RPCClient(port)
await backend.setup()

try:
yield backend
8 changes: 6 additions & 2 deletions vllm/entrypoints/openai/rpc/__init__.py
@@ -29,8 +29,12 @@ class RPCAbortRequest:
class RPCUtilityRequest(Enum):
IS_SERVER_READY = 1
GET_MODEL_CONFIG = 2
DO_LOG_STATS = 3
CHECK_HEALTH = 4
GET_DECODING_CONFIG = 3
GET_PARALLEL_CONFIG = 4
GET_SCHEDULER_CONFIG = 5
GET_LORA_CONFIG = 6
DO_LOG_STATS = 7
CHECK_HEALTH = 8


RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
115 changes: 88 additions & 27 deletions vllm/entrypoints/openai/rpc/client.py
@@ -1,10 +1,11 @@
import pickle
from typing import AsyncIterator, Mapping, Optional
from typing import Any, AsyncIterator, Mapping, Optional

import zmq
import zmq.asyncio

from vllm.config import DecodingConfig, ModelConfig
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
VLLM_RPC_HEALTHY_STR,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
@@ -14,24 +15,65 @@
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import (
_init_tokenizer_from_configs)


class RPCClient:

# TODO: check if opening all these sockets is an antipattern?
def __init__(self, tokenizer, port: int):
# ZMQ context.
def __init__(self, port: int):
self.context = zmq.asyncio.Context()

# TODO: do the tokenizer properly.
self.tokenizer = tokenizer
self.decoding_config = DecodingConfig()
self.path = f"tcp://localhost:{port}"

async def setup(self):
"""Setup the client before it starts sending server requests."""

# Wait until server is ready.
await self.wait_for_server()

# Get the configs.
self.model_config = await self._get_model_config_rpc()
self.decoding_config = await self._get_decoding_config_rpc()

# Create the tokenizer group.
# Note: this is a hack until we fully
self.tokenizer = _init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=(await self._get_scheduler_config_rpc()),
parallel_config=(await self._get_parallel_config_rpc()),
enable_lora=bool(await self._get_lora_config_rpc()),
)

def close(self):
"""Destroy the ZeroMQ Context."""
self.context.destroy()

async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
expected_type: Any,
error_message: str) -> Any:
"""Send an RPC request that is expecting data back."""

# Connect to socket.
socket = self.context.socket(zmq.constants.DEALER)
socket.connect(self.path)

# Ping RPCServer with a request.
await socket.send(pickle.dumps(request))

# Await the data from the Server.
data = pickle.loads(await socket.recv())
if not isinstance(data, expected_type):
# LoRAConfig can be None.
if expected_type == LoRAConfig and data is None:
pass
else:
socket.close()
raise ValueError(error_message)

socket.close()

return data

async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
error_message: str):
"""Send one-way RPC request to trigger an action."""
@@ -55,13 +97,14 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
return response

async def get_tokenizer(self, lora_request: LoRARequest):
# TODO: handle this via get data? - or avoid doing via RPC
return self.tokenizer
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

async def get_decoding_config(self):
# TODO: handle this via get data? - or avoid doing via RPC
return self.decoding_config

async def get_model_config(self):
return self.model_config

async def is_tracing_enabled(self):
# TODO: what is this?
return False
@@ -73,30 +116,48 @@ async def wait_for_server(self):
request=RPCUtilityRequest.IS_SERVER_READY,
error_message="Unable to start RPC Server.")

async def get_model_config(self) -> ModelConfig:
async def _get_model_config_rpc(self) -> ModelConfig:
"""Get the ModelConfig object from the RPC Server"""

# Connect to socket.
socket = self.context.socket(zmq.constants.DEALER)
socket.connect(self.path)
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_MODEL_CONFIG,
expected_type=ModelConfig,
error_message="Could not get ModelConfig from RPC Server")

# Ping RPCServer with GET_MODEL_CONFIG request.
await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
async def _get_decoding_config_rpc(self) -> DecodingConfig:
"""Get DecodingConfig from the RPCServer"""

# Await the MODEL_CONFIG from the Server.
model_config = pickle.loads(await socket.recv())
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_DECODING_CONFIG,
expected_type=DecodingConfig,
error_message="Could not get DecodingConfig from RPC Server")

if not isinstance(model_config, ModelConfig):
socket.close()
raise ValueError("Expected ModelConfig object from RPC, but "
f"got {model_config}")
async def _get_parallel_config_rpc(self) -> ParallelConfig:
"""Get ParallelConfig from the RPCServer"""

socket.close()
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_PARALLEL_CONFIG,
expected_type=ParallelConfig,
error_message="Could not get ModelConfig from RPC Server")

async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
"""Get SchedulerConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_SCHEDULER_CONFIG,
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")

async def _get_lora_config_rpc(self):
"""Get LoRAConfig from the RPCServer"""

return model_config
return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")

async def abort(self, request_id: str):
"""Send an RPCAbortRequest to the RPC Server"""
"""Send an ABORT_REQUEST signal to the RPC Server"""

await self._send_one_way_rpc_request(
request=RPCAbortRequest(request_id),
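A short usage sketch of the reworked client, mirroring the api_server.py change earlier in this PR; the asyncio entrypoint and the port value are placeholders, not part of the diff:

import asyncio

from vllm.entrypoints.openai.rpc.client import RPCClient


async def main(port: int):
    client = RPCClient(port)
    # setup() waits for the RPC server, fetches the configs over ZeroMQ,
    # and builds the tokenizer group client-side from those configs.
    await client.setup()

    model_config = await client.get_model_config()
    print(model_config.model)

    client.close()


# asyncio.run(main(port=5570))  # placeholder port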
51 changes: 43 additions & 8 deletions vllm/entrypoints/openai/rpc/server.py
@@ -37,22 +37,47 @@ def cleanup(self):
self.socket.close()
self.context.destroy()

async def _send_success_message(self, identity):
"""Send message to client indicating an action was successful."""
await self.socket.send_multipart([
identity,
pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
])

async def get_model_config(self, identity):
"""Send the ModelConfig """
"""Send the ModelConfig"""
model_config = await self.engine.get_model_config()

await self.socket.send_multipart(
[identity,
pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)])

async def get_decoding_config(self, identity):
"""Send the DecodingConfig"""
decoding_config = await self.engine.get_decoding_config()

await self.socket.send_multipart(
[identity,
pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)])

async def get_lora_config(self, identity):
"""Send the LoRAConfig"""
lora_config = await self.engine.get_lora_config()

await self.socket.send_multipart(
[identity,
pickle.dumps(lora_config, pickle.HIGHEST_PROTOCOL)])

async def get_scheduler_config(self, identity):
"""Send the SchedulerConfig"""
scheduler_config = await self.engine.get_scheduler_config()

await self.socket.send_multipart(
[identity,
pickle.dumps(scheduler_config, pickle.HIGHEST_PROTOCOL)])

async def get_parallel_config(self, identity):
"""Send the ParallelConfig"""
parallel_config = await self.engine.get_parallel_config()

await self.socket.send_multipart(
[identity,
pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)])

async def do_log_stats(self, identity):
"""Log stats and confirm success."""
await self.engine.do_log_stats()

await self.socket.send_multipart([
Expand All @@ -61,12 +86,14 @@ async def do_log_stats(self, identity):
])

async def is_server_ready(self, identity):
"""Notify the client that we are ready."""
await self.socket.send_multipart([
identity,
pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
])

async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)

@@ -120,6 +147,14 @@ def _make_handler_coro(self, identity,
elif isinstance(request, RPCUtilityRequest):
if request == RPCUtilityRequest.GET_MODEL_CONFIG:
return self.get_model_config(identity)
elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
return self.get_parallel_config(identity)
elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
return self.get_decoding_config(identity)
elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
return self.get_scheduler_config(identity)
elif request == RPCUtilityRequest.GET_LORA_CONFIG:
return self.get_lora_config(identity)
elif request == RPCUtilityRequest.DO_LOG_STATS:
return self.do_log_stats(identity)
elif request == RPCUtilityRequest.IS_SERVER_READY:
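The handlers above reply with send_multipart([identity, payload]) because a ROUTER socket prepends each incoming message with the sending DEALER's identity frame, which is then reused to route the reply back. A standalone toy sketch of that pattern, not vLLM code and with a placeholder port:

import pickle

import zmq
import zmq.asyncio


async def serve_one_request(port: int = 5570):  # placeholder port
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.ROUTER)
    socket.bind(f"tcp://*:{port}")

    # A DEALER client sends a single pickled frame; the ROUTER socket
    # prepends its identity, so we receive [identity, payload].
    identity, payload = await socket.recv_multipart()
    request = pickle.loads(payload)

    # Reuse the identity frame so the reply reaches the same client.
    await socket.send_multipart(
        [identity,
         pickle.dumps(f"got: {request!r}", pickle.HIGHEST_PROTOCOL)])

    socket.close()
    context.destroy()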