update to outlines010 (#1092)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
davidberenstein1957 and pre-commit-ci[bot] authored Jan 10, 2025
1 parent d9fd15c commit 9506930
Showing 6 changed files with 123 additions and 91 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -77,3 +77,4 @@ venv.bak/
 # Other
 *.log
 *.swp
+.DS_Store
17 changes: 12 additions & 5 deletions src/distilabel/models/llms/huggingface/transformers.py
@@ -23,6 +23,9 @@
 from distilabel.models.llms.utils import compute_tokens, prepare_output
 from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.steps.tasks.structured_outputs.outlines import (
+    _is_outlines_version_below_0_1_0,
+)
 from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
 from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR

@@ -111,6 +114,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):

     _pipeline: Optional["Pipeline"] = PrivateAttr(...)
     _prefix_allowed_tokens_fn: Union[Callable, None] = PrivateAttr(default=None)
+    _logits_processor: Union[Callable, None] = PrivateAttr(default=None)

     def load(self) -> None:
         """Loads the model and tokenizer and creates the text generation pipeline. In addition,
@@ -149,9 +153,11 @@ def load(self) -> None:
             self._pipeline.tokenizer.pad_token = self._pipeline.tokenizer.eos_token  # type: ignore

         if self.structured_output:
-            self._prefix_allowed_tokens_fn = self._prepare_structured_output(
-                self.structured_output
-            )
+            processor = self._prepare_structured_output(self.structured_output)
+            if _is_outlines_version_below_0_1_0():
+                self._prefix_allowed_tokens_fn = processor
+            else:
+                self._logits_processor = [processor]

         super().load()

@@ -232,7 +238,8 @@ def generate(  # type: ignore
             do_sample=do_sample,
             num_return_sequences=num_generations,
             prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
-            pad_token_id=self._pipeline.tokenizer.eos_token_id,  # type: ignore
+            pad_token_id=self._pipeline.tokenizer.eos_token_id,
+            logits_processor=self._logits_processor,
         )
         llm_output = [
             [generation["generated_text"] for generation in output]
@@ -292,7 +299,7 @@ def get_last_hidden_states(

     def _prepare_structured_output(
         self, structured_output: Optional[OutlinesStructuredOutputType] = None
-    ) -> Union[Callable, None]:
+    ) -> Union[Callable, List[Callable]]:
         """Creates the appropriate function to filter tokens to generate structured outputs.

         Args:
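A minimal sketch of the dispatch this change introduces (the helper name `build_generation_kwargs` and the version flag are illustrative, not part of the commit): under outlines < 0.1.0 the prepared processor is a prefix function that `transformers` consumes via `prefix_allowed_tokens_fn`, while from 0.1.0 onwards it is a logits processor that must be passed inside the `logits_processor` list.

```python
from typing import Any, Callable, Dict, Optional


def build_generation_kwargs(
    processor: Optional[Callable],
    outlines_below_0_1_0: bool,
) -> Dict[str, Any]:
    """Route a structured-output processor to the matching transformers kwarg."""
    kwargs: Dict[str, Any] = {}
    if processor is None:
        return kwargs  # no structured output requested
    if outlines_below_0_1_0:
        # outlines < 0.1.0 integrations return a prefix function
        kwargs["prefix_allowed_tokens_fn"] = processor
    else:
        # outlines >= 0.1.0 returns a logits processor, passed as a list
        kwargs["logits_processor"] = [processor]
    return kwargs
```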
11 changes: 8 additions & 3 deletions src/distilabel/models/llms/llamacpp.py
@@ -24,7 +24,12 @@
 from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType

 if TYPE_CHECKING:
-    from llama_cpp import CreateChatCompletionResponse, Llama, LogitsProcessorList
+    from llama_cpp import (
+        CreateChatCompletionResponse,
+        Llama,
+        LogitsProcessor,
+        LogitsProcessorList,
+    )

     from distilabel.steps.tasks.typing import FormattedInput, StandardInput

@@ -383,7 +388,7 @@ def generate(  # type: ignore

     def _prepare_structured_output(
         self, structured_output: Optional[OutlinesStructuredOutputType] = None
-    ) -> Union["LogitsProcessorList", None]:
+    ) -> Union["LogitsProcessorList", "LogitsProcessor"]:
         """Creates the appropriate function to filter tokens to generate structured outputs.

         Args:
@@ -399,4 +404,4 @@ def _prepare_structured_output(
         result = prepare_guided_output(structured_output, "llamacpp", self._model)
         if (schema := result.get("schema")) and self.structured_output:
             self.structured_output["schema"] = schema
-        return result["processor"]
+        return [result["processor"]]
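A hypothetical usage sketch of the updated `LlamaCppLLM` (the import path follows distilabel's current layout; the GGUF path is a placeholder). The list returned by `_prepare_structured_output` above is what `llama-cpp-python` receives as its `logits_processor` sequence:

```python
from pydantic import BaseModel

from distilabel.models.llms import LlamaCppLLM


class User(BaseModel):
    name: str
    age: int


llm = LlamaCppLLM(
    model_path="./model.gguf",  # placeholder path
    structured_output={"format": "json", "schema": User},
)
# llm.load() would now wrap the outlines processor in a list so that
# llama-cpp-python can consume it directly as a logits_processor sequence.
```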
52 changes: 2 additions & 50 deletions src/distilabel/models/llms/mlx.py
@@ -19,22 +19,18 @@
     Dict,
     List,
     Optional,
-    Union,
 )

 from pydantic import (
-    Field,
     PrivateAttr,
     validate_call,
 )

-from distilabel.mixins.runtime_parameters import RuntimeParameter
 from distilabel.models.llms.base import LLM
 from distilabel.models.llms.typing import GenerateOutput
 from distilabel.models.llms.utils import compute_tokens, prepare_output
 from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.steps.tasks.typing import (
-    OutlinesStructuredOutputType,
     StandardInput,
 )

@@ -51,8 +47,6 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
         tokenizer_config: the tokenizer configuration.
         model_config: the model configuration.
         adapter_path: the path to the adapter.
-        structured_output: a dictionary containing the structured output configuration or if more
-            fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
         use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
             template. Defaults to `False`.
         magpie_pre_query_template: the pre-query template to be applied to the prompt or
@@ -82,17 +76,10 @@ class MlxLLM(LLM, MagpieChatTemplateMixin):
     tokenizer_config: Dict[str, Any] = {}
     model_config: Dict[str, Any] = {}
     adapter_path: Optional[str] = None
-    structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field(
-        default=None,
-        description="The structured output format to use across all the generations.",
-    )

     _mlx_generate: Optional[Callable] = PrivateAttr(default=None)
     _model: Optional["nn.Module"] = PrivateAttr(...)
     _tokenizer: Optional["TokenizerWrapper"] = PrivateAttr(...)
-    _structured_output_logits_processor: Union[Callable, None] = PrivateAttr(
-        default=None
-    )

     def load(self) -> None:
         """Loads the model and tokenizer and creates the text generation pipeline. In addition,
@@ -112,11 +99,6 @@ def load(self) -> None:
             adapter_path=self.adapter_path,
         )

-        if self.structured_output:
-            self._structured_output_logits_processor = self._prepare_structured_output(
-                self.structured_output
-            )
-
         if self._tokenizer.pad_token is None:
             self._tokenizer.pad_token = self._tokenizer.eos_token

@@ -207,10 +189,6 @@ def generate(
         Returns:
             A list of lists of strings containing the generated responses for each input.
         """
-        logits_processors = []
-        if self._structured_output_logits_processor:
-            logits_processors.append(self._structured_output_logits_processor)
-
         structured_output = None
         result = []
         for input in inputs:
@@ -219,13 +197,9 @@

             output: List[str] = []
             for _ in range(num_generations):
-                if structured_output:
-                    additional_logits_processors = self._prepare_structured_output(
-                        structured_output
-                    )
-                    logits_processors.append(additional_logits_processors)
+                if structured_output:  # will raise a NotImplementedError
+                    self._prepare_structured_output(structured_output)
                 prompt = self.prepare_input(input)
-
                 generation = self._mlx_generate(
                     prompt=prompt,
                     model=self._model,
@@ -264,25 +238,3 @@ def generate(
                 )
             )
         return result
-
-    def _prepare_structured_output(
-        self, structured_output: Optional[OutlinesStructuredOutputType] = None
-    ) -> Union[Callable, None]:
-        """Creates the appropriate function to filter tokens to generate structured outputs.
-
-        Args:
-            structured_output: the configuration dict to prepare the structured output.
-
-        Returns:
-            The callable that will be used to guide the generation of the model.
-        """
-        from distilabel.steps.tasks.structured_outputs.outlines import (
-            prepare_guided_output,
-        )
-
-        result = prepare_guided_output(
-            structured_output, "transformers", self._pipeline
-        )
-        if schema := result.get("schema"):
-            self.structured_output["schema"] = schema
-        return result["processor"]
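Note that the removed method passed `self._pipeline` to `prepare_guided_output`, an attribute `MlxLLM` never defines (it holds `_model` and `_tokenizer`), so it could not have worked as written. The replacement body is not visible in this diff; a plausible sketch, assuming the new inline comment ("will raise a NotImplementedError") describes it accurately:

```python
# Hypothetical reconstruction -- the actual replacement is not shown in this view.
def _prepare_structured_output(self, structured_output=None) -> None:
    raise NotImplementedError(
        "Structured output is not currently supported for MlxLLM."
    )
```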
119 changes: 91 additions & 28 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
@@ -24,19 +24,38 @@
     Literal,
     Tuple,
     Type,
+    Union,
     get_args,
 )

+import pkg_resources
 from pydantic import BaseModel

 from distilabel.errors import DistilabelUserError
 from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict

-if TYPE_CHECKING:
-    from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
+if TYPE_CHECKING:  # noqa
+    from llama_cpp import Llama  # noqa
+    from transformers import Pipeline  # noqa
+    from vllm import LLM as _vLLM  # noqa
+
+    from distilabel.steps.tasks.typing import OutlinesStructuredOutputType  # noqa

 Frameworks = Literal["transformers", "llamacpp", "vllm"]
 """Available frameworks for the structured output configuration. """


+def _is_outlines_version_below_0_1_0() -> bool:
+    """Helper function to check outlines availability and version.
+
+    Returns:
+        bool: True if outlines is not installed or version is below 0.1.0
+    """
+    if not importlib.util.find_spec("outlines"):
+        raise ImportError(
+            "Outlines is not installed. Please install it using `pip install outlines`."
+        )
+    version = pkg_resources.get_distribution("outlines").version
+    return pkg_resources.parse_version(version) < pkg_resources.parse_version("0.1.0")


@@ -45,38 +64,77 @@ def model_to_schema(schema: Type[BaseModel]) -> Dict[str, Any]:


 def _get_logits_processor(framework: Frameworks) -> Tuple[Callable, Callable]:
-    """Helper function to return the appropriate logits processor for the given framework."""
-    if framework == "transformers":
-        from outlines.integrations.transformers import (
-            JSONPrefixAllowedTokens,
-            RegexPrefixAllowedTokens,
-        )
-
-        return JSONPrefixAllowedTokens, RegexPrefixAllowedTokens
-
-    if framework == "llamacpp":
-        from outlines.integrations.llamacpp import (
-            JSONLogitsProcessor,
-            RegexLogitsProcessor,
-        )
-
-        return JSONLogitsProcessor, RegexLogitsProcessor
-
-    if framework == "vllm":
-        from outlines.integrations.vllm import JSONLogitsProcessor, RegexLogitsProcessor
-
-        return JSONLogitsProcessor, RegexLogitsProcessor
-
-    raise DistilabelUserError(
-        f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
-        page="sections/how_to_guides/advanced/structured_generation/",
-    )
+    """Helper function to return the appropriate logits processors for the given framework."""
+    if _is_outlines_version_below_0_1_0():
+        processors = {
+            "transformers": (
+                "outlines.integrations.transformers",
+                "JSONPrefixAllowedTokens",
+                "RegexPrefixAllowedTokens",
+            ),
+            "llamacpp": (
+                "outlines.integrations.llamacpp",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+            "vllm": (
+                "outlines.integrations.vllm",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+        }
+    else:
+        processors = {
+            "transformers": (
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+            "llamacpp": (
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+            "vllm": (
+                "outlines.processors",
+                "JSONLogitsProcessor",
+                "RegexLogitsProcessor",
+            ),
+        }
+
+    if framework not in processors:
+        raise DistilabelUserError(
+            f"Invalid framework '{framework}'. Must be one of {get_args(Frameworks)}",
+            page="sections/how_to_guides/advanced/structured_generation/",
+        )
+
+    module_path, json_cls, regex_cls = processors[framework]
+    module = importlib.import_module(module_path)
+    return getattr(module, json_cls), getattr(module, regex_cls)
+
+
+def _get_tokenizer_from_model(
+    llm: Union["_vLLM", "Pipeline", "Llama"],
+    framework: Frameworks,
+) -> Callable:
+    if framework == "llamacpp":
+        from outlines.models.llamacpp import LlamaCppTokenizer
+
+        return LlamaCppTokenizer(llm)
+    if framework == "transformers":
+        from outlines.models.transformers import TransformerTokenizer
+
+        return TransformerTokenizer(llm.tokenizer)
+    if framework == "vllm":
+        from outlines.models.vllm import adapt_tokenizer
+
+        return adapt_tokenizer(llm.get_tokenizer())


 def prepare_guided_output(
     structured_output: "OutlinesStructuredOutputType",
     framework: Frameworks,
-    llm: Any,
+    llm: Union["_vLLM", "Pipeline", "Llama"],
 ) -> Dict[str, Any]:
     """Prepares the `LLM` to generate guided output using `outlines`.
@@ -97,10 +155,6 @@ def prepare_guided_output(
         case of "json" will also include the schema as a dict, to simplify serialization
         and deserialization.
     """
-    if not importlib.util.find_spec("outlines"):
-        raise ImportError(
-            "Outlines is not installed. Please install it using `pip install 'distilabel[outlines]'`."
-        )

     json_processor, regex_processor = _get_logits_processor(framework)

@@ -116,18 +170,27 @@ def prepare_guided_output(
     elif isinstance(schema, str):
         format = "regex"

+    if _is_outlines_version_below_0_1_0():
+        # use the llm for processor initialization
+        model = llm
+        tokenizer = None
+    else:
+        # use the tokenizer for processor initialization
+        model = None
+        tokenizer = _get_tokenizer_from_model(llm, framework)
+
     if format == "json":
         return {
             "processor": json_processor(
                 schema,
-                llm,
+                model or tokenizer,
                 whitespace_pattern=structured_output.get("whitespace_pattern"),
             ),
             "schema": schema_as_dict(schema),
         }

     if format == "regex":
-        return {"processor": regex_processor(schema, llm)}
+        return {"processor": regex_processor(schema, model or tokenizer)}

     raise DistilabelUserError(
         f"Invalid format '{format}'. Must be either 'json' or 'regex'.",