Added example and test for generating structured data to MlxLLM
dameikle committed Jan 22, 2025
1 parent e24470c commit ad91b0b
Showing 2 changed files with 72 additions and 1 deletion.
24 changes: 24 additions & 0 deletions src/distilabel/models/llms/mlx.py
@@ -65,6 +65,8 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
sent to the LLM to generate an instruction or a follow up user message. Valid
values are "llama3", "qwen2" or another pre-query template provided. Defaults
to `None`.
structured_output: a dictionary containing the structured output configuration or, if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to `None`.
Icon:
`:apple:`
@@ -82,6 +84,28 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]])
```
Generate structured data:
```python
from distilabel.models.llms import MlxLLM
from pydantic import BaseModel

class User(BaseModel):
    first_name: str
    last_name: str

llm = MlxLLM(
    path_or_hf_repo="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
    structured_output={"format": "json", "schema": User},
)

llm.load()

# Call the model
output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for John Smith"}]])
```
"""

path_or_hf_repo: str
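As context for the docstring example above (not part of the commit), here is a minimal sketch of how a structured generation might be validated after the fact: the Pydantic model doubles as the parser, and MLX chat models may append stop tokens such as `<|im_end|>` that need stripping first, exactly as the new test below does. The raw string here is hypothetical.

```python
import json

from pydantic import BaseModel


class User(BaseModel):
    first_name: str
    last_name: str


# Hypothetical raw model output: constrained decoding yields the JSON body,
# but the chat template's stop token can trail it.
raw = '{"first_name": "John", "last_name": "Smith"}<|im_end|>'
cleaned = raw.replace("<|im_end|>", "").strip()

user = User(**json.loads(cleaned))
assert user.first_name == "John" and user.last_name == "Smith"
```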
49 changes: 48 additions & 1 deletion tests/unit/models/llms/test_mlx.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import platform
from typing import Any, Dict, Generator

import pytest
from pydantic import BaseModel

from distilabel.models.llms.mlx import MlxLLM

@@ -63,6 +64,52 @@ def test_generate(self, llm: MlxLLM) -> None:
assert "input_tokens" in statistics
assert "output_tokens" in statistics

def test_structured_generation_json(self, llm: MlxLLM) -> None:
    class User(BaseModel):
        first_name: str
        last_name: str

    llm.structured_output = {
        "format": "json",
        "schema": User.model_json_schema(),
    }

    responses = llm.generate(
        inputs=[
            [{"role": "user", "content": "Create a user profile for John Smith"}],
        ],
        num_generations=1,
    )

    assert len(responses) == 1
    assert "generations" in responses[0]
    assert "statistics" in responses[0]
    generations = responses[0]["generations"]
    assert len(generations) == 1

    # Clean and parse each generation
    for generation in generations:
        # Strip the <|im_end|> stop tokens before parsing the JSON payload
        cleaned_json = generation.replace("<|im_end|>", "").strip()
        try:
            user_data = json.loads(cleaned_json)
            parsed_user = User(**user_data)
            assert isinstance(parsed_user, User)
            assert parsed_user.first_name == "John"
            assert parsed_user.last_name == "Smith"
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            print(f"Raw generation: {cleaned_json}")
            raise
        except ValueError as e:
            print(f"Validation error: {e}")
            raise

    statistics = responses[0]["statistics"]
    assert "input_tokens" in statistics
    assert "output_tokens" in statistics

@pytest.mark.parametrize(
"structured_output, dump",
[
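A side note on the cleanup step in the test above (not part of the commit): replacing a single stop token assumes the model emits nothing else around the JSON. A slightly more defensive sketch, using only the standard library, isolates the first valid JSON object instead; `extract_first_json` is a hypothetical helper, not distilabel API.

```python
import json

def extract_first_json(text: str) -> dict:
    # Scan for an opening brace and let json.JSONDecoder.raw_decode try to
    # parse a complete JSON value starting at that index.
    decoder = json.JSONDecoder()
    for start, char in enumerate(text):
        if char == "{":
            try:
                obj, _ = decoder.raw_decode(text, start)
                return obj
            except json.JSONDecodeError:
                continue
    raise ValueError(f"no JSON object found in: {text!r}")

print(extract_first_json('{"first_name": "John", "last_name": "Smith"}<|im_end|>'))
```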
