Add sampling from hugging face (#16)
neubig authored Mar 11, 2024
1 parent 873c60e commit f147f39
Showing 3 changed files with 90 additions and 27 deletions.
49 changes: 30 additions & 19 deletions llments/lm/base/empirical.py
@@ -1,4 +1,3 @@
from typing import Any
from llments.lm.lm import LanguageModel
import random
import json
@@ -11,21 +10,37 @@ def __init__(self, data: list[str], probs: list[float] | None = None):
probs = [1 / len(data)] * len(data)
self.data = pd.DataFrame({"text": data, "prob": probs})

def generate(self, condition: str | None, **kwargs: Any) -> str:
"""Sample from the language model, possibly conditioned on a prefix."""
if condition is None:
return random.choices(self.data["text"], weights=self.data["probs"])[0]
else:
# Filter to only those that start with the condition
def generate(
self,
condition: str | None,
do_sample: bool = False,
max_length: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
) -> list[str]:
"""See base class."""
filtered_df = self.data
# Adjust distribution
if condition:
filtered_df = self.data[self.data["text"].str.startswith(condition)]
if filtered_df.empty:
raise ValueError(
f"Condition {condition} does not match any strings in the "
"distribution."
)
# Normalize the probabilities
filtered_df["prob"] = filtered_df["prob"] / filtered_df["prob"].sum()
return random.choices(filtered_df["text"], weights=filtered_df["probs"])[0]
if not do_sample:
raise NotImplementedError("Greedy decoding is not implemented yet.")
if max_length is not None:
filtered_df = filtered_df[
filtered_df["text"].str.split().len() <= max_length
]
if temperature != 1.0:
raise NotImplementedError("Temperature is not implemented yet.")
if filtered_df.empty:
raise ValueError(
f"Condition {condition} does not match any strings in the "
"distribution."
)
# Normalize the probabilities (assign avoids mutating self.data in place)
filtered_df = filtered_df.assign(prob=filtered_df["prob"] / filtered_df["prob"].sum())
return random.choices(
filtered_df["text"].tolist(), weights=filtered_df["prob"].tolist(), k=num_return_sequences
)

def fit(self, target: LanguageModel, task_description: str | None = None):
raise ValueError(
@@ -36,10 +51,6 @@ def calculate_probability(self, x: str) -> float:
# Implementation logic
raise NotImplementedError("This is not implemented yet.")

def sample(self, condition: str | None, **kwargs) -> str:
# Implementation logic
raise NotImplementedError("This is not implemented yet.")


def load_from_text_file(text_file: str):
"""Load the distribution from a text file."""
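For orientation (not part of the commit), here is a minimal sketch of the prefix-filter-and-renormalize sampling that the new generate method implements; the toy strings and uniform weights are illustrative assumptions:

import random
import pandas as pd

# Toy empirical distribution with uniform weights, mirroring __init__.
data = ["hello world", "hello there", "goodbye world"]
df = pd.DataFrame({"text": data, "prob": [1 / len(data)] * len(data)})

# Keep only the strings that start with the conditioning prefix.
condition = "hello"
filtered = df[df["text"].str.startswith(condition)]

# Renormalize so the conditional probabilities sum to 1, then sample.
filtered = filtered.assign(prob=filtered["prob"] / filtered["prob"].sum())
samples = random.choices(
    filtered["text"].tolist(), weights=filtered["prob"].tolist(), k=2
)
print(samples)  # e.g. ['hello there', 'hello world']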
51 changes: 45 additions & 6 deletions llments/lm/base/hugging_face.py
@@ -1,14 +1,20 @@
from llments.lm.lm import LanguageModel
from transformers import pipeline


class HuggingFaceLM(LanguageModel):
def sample(
def __init__(
self,
condition: str | None,
**kwargs,
) -> str:
"""Generate from the language model, possibly conditioned on a prefix."""
raise NotImplementedError("This is not implemented yet.")
model: str,
device: str | None = None,
):
"""Initialize a HuggingFaceLM.
Args:
model: The name of the model.
device: The device to run the model on.
"""
self.text_generator = pipeline("text-generation", model=model, device=device)

def fit(
self, target: LanguageModel, task_description: str | None = None
@@ -25,6 +31,39 @@ def fit(
"""
raise NotImplementedError("This is not implemented yet.")

def generate(
self,
condition: str | None,
do_sample: bool = False,
max_length: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
) -> list[str]:
"""Generate an output given the language model.
Args:
condition: The conditioning sequence for the output.
If None, the output is not conditioned.
do_sample: Whether to use sampling or greedy decoding.
max_length: The maximum length of the output sequence,
(defaults to model max).
temperature: The value used to modulate the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.

Returns:
list[str]: Sampled output sequences from the language model.
"""
results = self.text_generator(
condition,
do_sample=do_sample,
max_length=max_length,
temperature=temperature,
num_return_sequences=num_return_sequences,
clean_up_tokenization_spaces=True,
)
return [res["generated_text"] for res in results]


def load_from_spec(spec_file: str) -> HuggingFaceLM:
"""Load a language model from a specification file.
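As a usage sketch (again not part of the commit), the new class can be driven roughly like this; "gpt2" and device="cpu" are stand-in choices, not anything the commit prescribes:

from llments.lm.base.hugging_face import HuggingFaceLM

# Any causal LM checkpoint name from the Hugging Face Hub should work here.
lm = HuggingFaceLM("gpt2", device="cpu")
outputs = lm.generate(
    "Once upon a time",
    do_sample=True,            # sample rather than greedy-decode
    max_length=30,             # cap on total length (prompt + continuation)
    temperature=0.8,           # values < 1.0 sharpen the next-token distribution
    num_return_sequences=2,
)
for text in outputs:
    print(text)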
17 changes: 15 additions & 2 deletions llments/lm/lm.py
@@ -15,14 +15,27 @@ def calculate_probability(self, output: str) -> float:
...

@abc.abstractmethod
def generate(self, condition: str | None) -> str:
def generate(
self,
condition: str | None,
do_sample: bool = False,
max_length: int | None = None,
temperature: float = 1.0,
num_return_sequences: int = 1,
) -> list[str]:
"""Generate an output given the language model.
Args:
condition: The conditioning sequence for the output.
If None, the output is not conditioned.
do_sample: Whether to use sampling or greedy decoding.
max_length: The maximum length of the output sequence,
(defaults to model max).
temperature: The value used to modulate the next token probabilities.
num_return_sequences: The number of independently computed returned
sequences for each element in the batch.

Returns:
str: A sampled output sequence from the language model.
list[str]: Sampled output sequences from the language model.
"""
...

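To make the updated contract concrete, here is a hypothetical minimal subclass (assuming calculate_probability and generate are the abstract methods; any other abstract members of LanguageModel would also need overriding):

from llments.lm.lm import LanguageModel


class EchoLM(LanguageModel):
    """Toy model that simply echoes its conditioning prefix."""

    def calculate_probability(self, output: str) -> float:
        return 1.0  # toy value: treat every output as certain

    def generate(
        self,
        condition: str | None,
        do_sample: bool = False,
        max_length: int | None = None,
        temperature: float = 1.0,
        num_return_sequences: int = 1,
    ) -> list[str]:
        text = condition or ""
        if max_length is not None:
            # Interpret max_length in whitespace tokens, as empirical.py does.
            text = " ".join(text.split()[:max_length])
        return [text] * num_return_sequences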