Commit 79d74a9

revert test file
V2arK committed Jan 17, 2025
1 parent 9bce587 commit 79d74a9
Showing 1 changed file with 3 additions and 56 deletions.
59 changes: 3 additions & 56 deletions llama_stack/providers/tests/inference/test_text_inference.py
@@ -32,6 +32,7 @@
     UserMessage,
 )
 from llama_stack.apis.models import Model
+
 from .utils import group_chunks
 
 
@@ -42,28 +43,6 @@
 # --env FIREWORKS_API_KEY=<your_api_key>
 
 
-def skip_if_centml_tool_call(provider):
-    """
-    Skip tool-calling tests if the provider is remote::centml,
-    because CentML currently doesn't generate tool_call responses.
-    """
-    if provider.__provider_spec__.provider_type == "remote::centml":
-        pytest.skip(
-            "CentML does not currently return tool calls. Skipping tool-calling test."
-        )
-
-
-def skip_if_centml_and_8b(inference_model, inference_impl):
-    """
-    Skip if provider is CentML and the model is 8B.
-    CentML only supports 'meta-llama/Llama-3.2-3B-Instruct'.
-    """
-    provider = inference_impl.routing_table.get_provider_impl(inference_model)
-    if provider.__provider_spec__.provider_type == "remote::centml" and "8b" in inference_model.lower(
-    ):
-        pytest.skip("CentML does not support Llama-3.1 8B model.")
-
-
 def get_expected_stop_reason(model: str):
     return (
         StopReason.end_of_message
@@ -111,11 +90,7 @@ class TestInference:
     # share the same provider instance.
     @pytest.mark.asyncio(loop_scope="session")
     async def test_model_list(self, inference_model, inference_stack):
-        inference_impl, models_impl = inference_stack
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
+        _, models_impl = inference_stack
         response = await models_impl.list_models()
         assert isinstance(response, list)
         assert len(response) >= 1
@@ -133,9 +108,6 @@ async def test_model_list(self, inference_model, inference_stack):
     async def test_completion(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
 
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
@@ -181,9 +153,6 @@ async def test_completion(self, inference_model, inference_stack):
     async def test_completion_logprobs(self, inference_model, inference_stack):
         inference_impl, _ = inference_stack
 
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             # "remote::nvidia", -- provider doesn't provide all logprobs
@@ -287,10 +256,6 @@ async def test_chat_completion_non_streaming(
         self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         response = await inference_impl.chat_completion(
             model_id=inference_model,
             messages=sample_messages,
@@ -379,10 +344,6 @@ async def test_chat_completion_streaming(
         self, inference_model, inference_stack, common_params, sample_messages
     ):
         inference_impl, _ = inference_stack
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
         response = [
             r
             async for r in await inference_impl.chat_completion(
@@ -416,13 +377,6 @@ async def test_chat_completion_with_tool_calling(
     ):
         inference_impl, _ = inference_stack
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
-        # Skip if CentML (it doesn't produce tool calls yet)
-        skip_if_centml_tool_call(provider)
-
         if (
             provider.__provider_spec__.provider_type == "remote::groq"
             and "Llama-3.2" in inference_model
@@ -470,13 +424,6 @@ async def test_chat_completion_with_tool_calling_streaming(
     ):
         inference_impl, _ = inference_stack
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
-
-        # Skip if 8B + CentML
-        skip_if_centml_and_8b(inference_model, inference_impl)
-
-        # Skip if CentML (it doesn't produce tool calls yet)
-        skip_if_centml_tool_call(provider)
-
         if (
             provider.__provider_spec__.provider_type == "remote::groq"
             and "Llama-3.2" in inference_model
@@ -530,7 +477,7 @@ async def test_chat_completion_with_tool_calling_streaming(
         last = grouped[ChatCompletionResponseEventType.progress][-1]
         # assert last.event.stop_reason == expected_stop_reason
         assert last.event.delta.parse_status == ToolCallParseStatus.succeeded
-        assert last.event.delta.content.type == "tool_call"
+        assert isinstance(last.event.delta.content, ToolCall)
 
         call = last.event.delta.content
         assert call.tool_name == "get_weather"
