diff --git a/llments/lm/base/empirical.py b/llments/lm/base/empirical.py
index a77aef0..e9e8399 100644
--- a/llments/lm/base/empirical.py
+++ b/llments/lm/base/empirical.py
@@ -1,4 +1,3 @@
-from typing import Any
 from llments.lm.lm import LanguageModel
 import random
 import json
@@ -11,21 +10,37 @@ def __init__(self, data: list[str], probs: list[float] | None = None):
             probs = [1 / len(data)] * len(data)
         self.data = pd.DataFrame({"text": data, "prob": probs})
 
-    def generate(self, condition: str | None, **kwargs: Any) -> str:
-        """Sample from the language model, possibly conditioned on a prefix."""
-        if condition is None:
-            return random.choices(self.data["text"], weights=self.data["probs"])[0]
-        else:
-            # Filter to only those that start with the condition
+    def generate(
+        self,
+        condition: str | None,
+        do_sample: bool = False,
+        max_length: int | None = None,
+        temperature: float = 1.0,
+        num_return_sequences: int = 1,
+    ) -> list[str]:
+        """See base class."""
+        filtered_df = self.data
+        # Restrict the distribution to strings that start with the condition
+        if condition:
             filtered_df = self.data[self.data["text"].str.startswith(condition)]
-            if filtered_df.empty:
-                raise ValueError(
-                    f"Condition {condition} does not match any strings in the "
-                    "distribution."
-                )
-            # Normalize the probabilities
-            filtered_df["prob"] = filtered_df["prob"] / filtered_df["prob"].sum()
-            return random.choices(filtered_df["text"], weights=filtered_df["probs"])[0]
+        if not do_sample:
+            raise NotImplementedError("Greedy decoding is not implemented yet.")
+        if max_length is not None:
+            filtered_df = filtered_df[
+                filtered_df["text"].str.split().str.len() <= max_length
+            ]
+        if temperature != 1.0:
+            raise NotImplementedError("Temperature is not implemented yet.")
+        if filtered_df.empty:
+            raise ValueError(
+                f"Condition {condition} does not match any strings in the "
+                "distribution."
+            )
+        # Normalize the probabilities
+        filtered_df["prob"] = filtered_df["prob"] / filtered_df["prob"].sum()
+        return random.choices(
+            filtered_df["text"].tolist(), weights=filtered_df["prob"], k=num_return_sequences
+        )
 
     def fit(self, target: LanguageModel, task_description: str | None = None):
         raise ValueError(
@@ -36,10 +51,6 @@ def calculate_probability(self, x: str) -> float:
         # Implementation logic
         raise NotImplementedError("This is not implemented yet.")
 
-    def sample(self, condition: str | None, **kwargs) -> str:
-        # Implementation logic
-        raise NotImplementedError("This is not implemented yet.")
-
 
 def load_from_text_file(text_file: str):
     """Load the distribution from a text file."""
diff --git a/llments/lm/base/hugging_face.py b/llments/lm/base/hugging_face.py
index 7ed6eb3..52ae6f3 100644
--- a/llments/lm/base/hugging_face.py
+++ b/llments/lm/base/hugging_face.py
@@ -1,14 +1,20 @@
 from llments.lm.lm import LanguageModel
+from transformers import pipeline
 
 
 class HuggingFaceLM(LanguageModel):
-    def sample(
+    def __init__(
         self,
-        condition: str | None,
-        **kwargs,
-    ) -> str:
-        """Generate from the language model, possibly conditioned on a prefix."""
-        raise NotImplementedError("This is not implemented yet.")
+        model: str,
+        device: str | None = None,
+    ):
+        """Initialize a HuggingFaceLM.
+
+        Args:
+            model: The name of the model.
+            device: The device to run the model on.
+ """ + self.text_generator = pipeline("text-generation", model=model, device=device) def fit( self, target: LanguageModel, task_description: str | None = None @@ -25,6 +31,39 @@ def fit( """ raise NotImplementedError("This is not implemented yet.") + def generate( + self, + condition: str | None, + do_sample: bool = False, + max_length: int | None = None, + temperature: float = 1.0, + num_return_sequences: int = 1, + ) -> list[str]: + """Generate an output given the language model. + + Args: + condition: The conditioning sequence for the output. + If None, the output is not conditioned. + do_sample: Whether to use sampling or greedy decoding. + max_length: The maximum length of the output sequence, + (defaults to model max). + temperature: The value used to module the next token probabilities. + num_return_sequences: The number of independently computed returned + sequences for each element in the batch. + + Returns: + str: A sampled output sequence from the language model. + """ + results = self.text_generator( + condition, + do_sample=do_sample, + max_length=max_length, + temperature=temperature, + num_return_sequences=num_return_sequences, + clean_up_tokenization_spaces=True, + ) + return [res["generated_text"] for res in results] + def load_from_spec(spec_file: str) -> HuggingFaceLM: """Load a language model from a specification file. diff --git a/llments/lm/lm.py b/llments/lm/lm.py index 5989ae1..862a708 100644 --- a/llments/lm/lm.py +++ b/llments/lm/lm.py @@ -15,14 +15,27 @@ def calculate_probability(self, output: str) -> float: ... @abc.abstractmethod - def generate(self, condition: str | None) -> str: + def generate( + self, + condition: str | None, + do_sample: bool = False, + max_length: int | None = None, + temperature: float = 1.0, + num_return_sequences: int = 1, + ) -> list[str]: """Generate an output given the language model. Args: condition: The conditioning sequence for the output. If None, the output is not conditioned. + do_sample: Whether to use sampling or greedy decoding. + max_length: The maximum length of the output sequence, + (defaults to model max). + temperature: The value used to module the next token probabilities. + num_return_sequences: The number of independently computed returned + sequences for each element in the batch. Returns: - str: A sampled output sequence from the language model. + str: Sampled output sequences from the language model. """ ...