Merge pull request #51 from mlfoundations/jean/curator_model
Add curator to handle inference for the model being evaluated
neginraoof authored Jan 30, 2025
2 parents 7b2077b + e31b2ce commit 3f6a677
Showing 6 changed files with 223 additions and 1 deletion.
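The new "curator" backend registers an API-served model (for example Claude or Gemini) with lm-eval-harness, so the model being evaluated no longer has to be loaded locally. A rough sketch of how the registration might be picked up, assuming provider credentials are configured in the environment (the registry lookup is standard lm-eval usage, not something added by this commit):

# Sketch only: resolve the newly registered backend through lm-eval's model registry.
from lm_eval.api.registry import get_model

import eval.chat_benchmarks.curator_lm  # noqa: F401  (importing the module registers "curator")

CuratorModel = get_model("curator")
lm = CuratorModel(pretrained="claude-3-5-haiku-20241022", tokenized_requests=False)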
1 change: 1 addition & 0 deletions .gitignore
@@ -225,3 +225,4 @@ eval/chat_benchmarks/MixEval/mix_eval/data/test/
eval/chat_benchmarks/MixEval/mix_eval/data/model_responses
eval/chat_benchmarks/MixEval/mix_eval/eval_scripts
eval/chat_benchmarks/MixEval/results
results/
7 changes: 6 additions & 1 deletion eval/chat_benchmarks/WildBench/eval_instruct.py
@@ -166,8 +166,13 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
# Load data
id_strs, chat_history, extracted_chats, metadata = self.load_dataset()

# Remove extra fields that might not be compatible with some models
simplified_extracted_chats = [
[{"role": c["role"], "content": c["content"]} for c in chat] for chat in extracted_chats
]

# Prepare model inputs
model_inputs = [model.apply_chat_template(chat) for chat in extracted_chats]
model_inputs = [model.apply_chat_template(chat) for chat in simplified_extracted_chats]

# Create temporary directory
temp_dir_obj = tempfile.TemporaryDirectory()
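For illustration, this is what the simplification above does to a chat turn that carries extra bookkeeping fields; the field names "turn_id" and "source" are hypothetical stand-ins for whatever extra keys the WildBench data may include:

# Hypothetical chat with extra fields that some chat-completion APIs reject.
chat = [
    {"role": "user", "content": "Hello!", "turn_id": 0, "source": "wildbench"},
    {"role": "assistant", "content": "Hi there!", "turn_id": 1, "source": "wildbench"},
]
# Keep only the keys that chat templates and APIs universally accept.
simplified = [{"role": c["role"], "content": c["content"]} for c in chat]
# simplified == [{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi there!"}]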
@@ -0,0 +1,5 @@
,np.tanh(std_delta_len),instruction_difficulty,not_gamed_baseline.astype(float)
"model_curator_model_args_pretrained=claude-3-5-haiku-20241022,tokenized_requests=False",-1.8366974017900992,0.7446564084265157,-7.2599824829397441
"model_curator_model_args_pretrained=claude-3-5-haiku-20241022,tokenized_requests=False,batch_size=1",-2.2123505700477266,0.4430662251338714,0.0602887165933287
"model_curator_model_args_pretrained=gemini__gemini-1.5-flash,tokenized_requests=False,batch_size=1",-2.0652896814524828,0.2909775910083418,0.0988106449800108
"model_curator_model_args_pretrained=gemini__gemini-1.5-flash-8b,tokenized_requests=False,max_requests_per_minute=10000,max_tokens_per_minute=10000000,batch_size=1",-2.3695718018361474,0.3369055355846819,-0.6677889603940012
209 changes: 209 additions & 0 deletions eval/chat_benchmarks/curator_lm.py
@@ -0,0 +1,209 @@
import json
import os
import time
from typing import Any, Dict, List, Optional, Tuple, Union

from bespokelabs import curator
from datasets import Dataset
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.api_models import JsonChatStr
from lm_eval.models.utils import handle_stop_sequences


@register_model("curator")
class CuratorAPIModel(TemplateLM):
def __init__(
self,
model: str = None,
pretrained: str = None,
max_length: Optional[int] = 2048,
max_retries: int = 10,
timeout: int = 300,
tokenized_requests: bool = False,
max_requests_per_minute: int = None,
max_tokens_per_minute: int = None,
seconds_to_pause_on_rate_limit: int = None,
**kwargs,
):
super().__init__()

self.model_name = model or pretrained

if "gemini" in self.model_name:
max_requests_per_minute = max_requests_per_minute or 2000
max_tokens_per_minute = max_tokens_per_minute or 4_000_000
elif "claude" in self.model_name:
max_requests_per_minute = max_requests_per_minute or 2000
max_tokens_per_minute = max_tokens_per_minute or 80_000

if tokenized_requests:
raise NotImplementedError("Tokenized requests not implemented for curator.")
self.tokenized_requests = False
self.max_length = max_length
self.llm = None
self.gen_kwargs = {}
self.eos = None
self.backend_params = {
"invalid_finish_reasons": [
"content_filter"
], # So it doesn't retry on `length` finish reason, but retries on "content_filter"}
"require_all_responses": False,
"request_timeout": timeout,
"max_retries": max_retries,
}
if max_requests_per_minute is not None:
self.backend_params["max_requests_per_minute"] = max_requests_per_minute
if max_tokens_per_minute is not None:
self.backend_params["max_tokens_per_minute"] = max_tokens_per_minute
if seconds_to_pause_on_rate_limit is not None:
self.backend_params["seconds_to_pause_on_rate_limit"] = seconds_to_pause_on_rate_limit

# Disable cache since it is not necessary
os.environ["CURATOR_DISABLE_CACHE"] = "true"

def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
*,
generate: bool = False,
gen_kwargs: Optional[dict] = None,
eos=None,
**kwargs,
) -> dict:
assert generate, "Curator only supports generation."
# Create the payload for the API request
max_tokens = gen_kwargs.get("max_gen_toks", self.max_length)
temperature = gen_kwargs.get("temperature", 0)
stop = handle_stop_sequences(gen_kwargs.get("until", None), eos)
gen_kwargs = {
"max_completion_tokens": max_tokens,
"temperature": temperature,
"stop": stop,
}
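        # Note: curator.LLM fixes its generation parameters at construction time, so the
        # client is created lazily on first use and rebuilt only if gen_kwargs change.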
if self.llm is None:
self.eos = eos
self.gen_kwargs = gen_kwargs.copy()
self.llm = curator.LLM(
model_name=self.model_name, generation_params=gen_kwargs, backend_params=self.backend_params.copy()
)
else:
if self.gen_kwargs != gen_kwargs:
                print(
                    "Recreating curator LLM with new generation parameters; "
                    "make sure this does not happen on every request."
                )
self.gen_kwargs = gen_kwargs.copy()
self.llm = curator.LLM(
model_name=self.model_name, generation_params=gen_kwargs, backend_params=self.backend_params.copy()
)
return messages

def create_message(
self, messages: Union[List[List[int]], List[str], List[JsonChatStr]], generate=False
) -> Union[List[List[int]], List[dict], List[str], str]:
# Convert messages to the format expected by the API
if isinstance(messages, list) and all(isinstance(m, JsonChatStr) for m in messages):
return [json.loads(m.prompt) for m in messages]
else:
raise ValueError("Messages must be a list of JsonChatStr objects")

@staticmethod
def parse_logprobs(
outputs: Union[Any, List[Any]], tokens: List[List[int]] = None, ctxlen: List[int] = None, **kwargs
) -> List[Tuple[float, bool]]:
# Implement log probability parsing logic
raise NotImplementedError("Log probability parsing not implemented.")
logprobs = []
for output in outputs:
# Assuming output has a structure that includes log probabilities
logprob = output.get("logprob", 0.0) # Replace with actual key
is_greedy = output.get("is_greedy", False) # Replace with actual key
logprobs.append((logprob, is_greedy))
return logprobs

@staticmethod
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
# Parse the generated outputs from the API
return [output["response"] for output in outputs]

@property
def tokenizer_name(self) -> str:
return self.model_name

def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> Union[str, JsonChatStr]:
# Convert chat history to the required format
return JsonChatStr(json.dumps(chat_history))

def model_call(self, messages: Union[List[List[int]], List[str], List[JsonChatStr]], **kwargs) -> Optional[dict]:
payload = self._create_payload(self.create_message(messages), **kwargs)
response = self.llm(payload)["response"]
return response

def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
raise NotImplementedError("Log likelihood tokens not implemented for curator.")
results = []
for context, continuation in requests:
# Assuming the model can compute log likelihoods
response = self.model_call([context, continuation])
logprob = response.get("logprob", 0.0) # Replace with actual key
is_greedy = response.get("is_greedy", False) # Replace with actual key
results.append((logprob, is_greedy))
return results

@property
def eot_token_id(self) -> Optional[int]:
# Assuming the model has a specific end-of-text token ID
return self.llm.eot_token_id # Replace with actual method to get EOT token ID

    def _group_gen_kwargs(self, gen_kwargs: List[dict]) -> Tuple[List[List[int]], List[dict]]:
"""
        Group identical generation parameters together; return, for each distinct parameter set,
        the list of request indices that share it, along with the distinct sets in order of first appearance.
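
        Example (illustrative): [{"t": 0}, {"t": 1}, {"t": 0}] -> ([[0, 2], [1]], [{"t": 0}, {"t": 1}])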
"""
ordered_set_gen_kwargs = []
for gkw in gen_kwargs:
if gkw not in ordered_set_gen_kwargs:
ordered_set_gen_kwargs.append(gkw)
return [
[i for i, gkw in enumerate(gen_kwargs) if gkw == ordered_set_gen_kwargs[j]]
for j in range(len(ordered_set_gen_kwargs))
], ordered_set_gen_kwargs

def generate_until(self, requests: List[Instance], disable_tqdm: bool = False) -> List[str]:
# Tokenize contexts if required
if self.tokenized_requests:
raise NotImplementedError("Tokenized requests not implemented for curator.")

# Extract contexts and generation kwargs from the Instance objects
contexts = [req.args[0] for req in requests]
gen_kwargs = [req.args[1] for req in requests]

list_indices, gen_kwargs = self._group_gen_kwargs(gen_kwargs)

# Curator needs a new object for each set of gen_kwargs
# Generate responses for each group of generation parameters
response_list = []
for indices, gkw in zip(list_indices, gen_kwargs):
contexts_dataset = self.create_message([contexts[i] for i in indices])
payload = self._create_payload(contexts_dataset, generate=True, gen_kwargs=gkw)
response_list.append(self.llm(payload)["response"])

# Re-order responses to match the original request order
response = [None] * len(requests)
for i, indices in enumerate(list_indices):
for idx in indices:
response[idx] = response_list[i]
return response

def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> List[float]:
raise NotImplementedError("Log likelihood rolling not implemented for curator.")
loglikelihoods = []
for context in requests:
response = self.model_call(context)
loglikelihood = response.get("loglikelihood", 0.0) # Replace with actual key
loglikelihoods.append(loglikelihood)
return loglikelihoods

def tok_encode(self, string: str, **kwargs) -> List[int]:
raise NotImplementedError("Token encoding not implemented for curator.")
return self.llm.tokenizer.encode(string) # Replace with actual method to tokenize
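
A minimal end-to-end sketch of the generation path defined above (illustrative; it assumes provider credentials are set in the environment and mirrors what generate_until does for each group of requests):

# Illustrative use of CuratorAPIModel; the model name is taken from the results CSV earlier in this commit.
lm = CuratorAPIModel(pretrained="claude-3-5-haiku-20241022", tokenized_requests=False)

# lm-eval hands the model JSON-encoded chats via apply_chat_template ...
chat = lm.apply_chat_template([{"role": "user", "content": "Name one prime number."}])

# ... and model_call builds the curator client and forwards the request with the
# desired generation settings.
response = lm.model_call([chat], generate=True, gen_kwargs={"max_gen_toks": 32, "temperature": 0})
print(response)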
1 change: 1 addition & 0 deletions eval/eval.py
@@ -23,6 +23,7 @@
import lm_eval.api.task
import lm_eval.models

from eval.chat_benchmarks.curator_lm import CuratorAPIModel # register curator model
from eval.task import TaskManager as InstructTaskManager
from eval.eval_tracker import DCEvaluationTracker

1 change: 1 addition & 0 deletions pyproject.toml
@@ -104,6 +104,7 @@ dependencies = [
"requests>=2.28",
"websocket",
"aiofiles",
"bespokelabs-curator>=0.16.0",

# Database
"sqlalchemy",
