From c7d2dc61ea2ba0a9f25299f99ba2d285592b6bf8 Mon Sep 17 00:00:00 2001
From: Jean Mercat <jean.mercat@tri.global>
Date: Thu, 16 Jan 2025 11:52:00 -0800
Subject: [PATCH] WildBench: simplify chat input to be compatible with all
 models; add curator model (doesn't work as is)

---
 .../WildBench/eval_instruct.py     |   7 +-
 eval/chat_benchmarks/curator_lm.py | 150 +++++++++++++++++++
 eval/eval.py                       |   1 +
 pyproject.toml                     |   1 +
 4 files changed, 158 insertions(+), 1 deletion(-)
 create mode 100644 eval/chat_benchmarks/curator_lm.py

diff --git a/eval/chat_benchmarks/WildBench/eval_instruct.py b/eval/chat_benchmarks/WildBench/eval_instruct.py
index 100e6d68..4d44aa97 100644
--- a/eval/chat_benchmarks/WildBench/eval_instruct.py
+++ b/eval/chat_benchmarks/WildBench/eval_instruct.py
@@ -166,8 +166,13 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         # Load data
         id_strs, chat_history, extracted_chats, metadata = self.load_dataset()
 
+        # Remove extra fields that might not be compatible with some models
+        simplified_extracted_chats = [
+            [{"role": c["role"], "content": c["content"]} for c in chat] for chat in extracted_chats
+        ]
+
         # Prepare model inputs
-        model_inputs = [model.apply_chat_template(chat) for chat in extracted_chats]
+        model_inputs = [model.apply_chat_template(chat) for chat in simplified_extracted_chats]
 
         # Create temporary directory
         temp_dir_obj = tempfile.TemporaryDirectory()
diff --git a/eval/chat_benchmarks/curator_lm.py b/eval/chat_benchmarks/curator_lm.py
new file mode 100644
index 00000000..e2dca791
--- /dev/null
+++ b/eval/chat_benchmarks/curator_lm.py
@@ -0,0 +1,150 @@
+from typing import List, Dict, Any, Optional, Union, Tuple
+import json
+
+from bespokelabs import curator
+from datasets import Dataset
+
+from lm_eval.api.model import TemplateLM
+from lm_eval.models.api_models import JsonChatStr
+from lm_eval.api.registry import register_model
+from lm_eval.api.instance import Instance
+from lm_eval.models.utils import handle_stop_sequences
+
+
+class Prompter(curator.LLM):
+    """Pass each row's chat straight to the API and keep the raw response."""
+
+    def prompt(self, row):
+        return row["messages"]
+
+    def parse(self, row, response):
+        return {"response": response}
+
+
+@register_model("curator")
+class CuratorAPIModel(TemplateLM):
+    def __init__(
+        self,
+        model: str = None,
+        pretrained: str = None,
+        max_length: Optional[int] = 2048,
+        num_concurrent: int = 1,
+        max_retries: int = 3,
+        timeout: int = 300,
+        tokenized_requests: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        if tokenized_requests:
+            raise NotImplementedError("Tokenized requests not implemented for curator.")
+        self.tokenized_requests = False
+        self.model_name = model or pretrained
+        self.max_length = max_length
+        self.num_concurrent = num_concurrent
+        self.max_retries = max_retries
+        self.timeout = timeout
+        self.llm = None
+        self.gen_kwargs = {}
+        self._max_gen_toks = 2048
+        self.eos = None
+
+    def _create_payload(
+        self,
+        messages: Union[List[List[int]], List[dict], List[str], str],
+        *,
+        generate: bool = False,
+        gen_kwargs: Optional[dict] = None,
+        eos=None,
+        **kwargs,
+    ) -> dict:
+        assert generate, "Curator only supports generation."
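+        # The curator client is built lazily on the first request and reused;
+        # curator caches responses per (model, generation params) pair, so all
+        # later requests must carry identical parameters.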
+        if self.llm is None:
+            gen_kwargs = gen_kwargs or {}
+            self.gen_kwargs = gen_kwargs.copy()
+            self.eos = eos
+            max_tokens = gen_kwargs.get("max_gen_toks", self._max_gen_toks)
+            temperature = gen_kwargs.get("temperature", 0)
+            stop = handle_stop_sequences(gen_kwargs.get("until", None), eos)
+            generation_params = {
+                "max_completion_tokens": max_tokens,
+                "temperature": temperature,
+                "stop": stop,
+            }
+            self.llm = Prompter(model_name=self.model_name, generation_params=generation_params)
+        else:
+            assert self.gen_kwargs == gen_kwargs, "Generation parameters must be the same for all requests in curator"
+            assert self.eos == eos, "EOS must be the same for all requests in curator"
+        return messages
+
+    def create_message(
+        self, messages: Union[List[List[int]], List[str], List[JsonChatStr]], generate=False
+    ) -> Union[List[List[int]], List[dict], List[str], str]:
+        # Convert a batch of JSON-encoded chats into a single dataset with one
+        # chat per row, the input format curator expects
+        if isinstance(messages, list) and all(isinstance(m, JsonChatStr) for m in messages):
+            return Dataset.from_dict({"messages": [json.loads(m.prompt) for m in messages]})
+        return messages
+
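+    # Illustrative request round trip (shapes are assumptions from the code
+    # above and below, not exercised by tests):
+    #   apply_chat_template([{"role": "user", "content": "Hi"}])
+    #       -> JsonChatStr('[{"role": "user", "content": "Hi"}]')
+    #   create_message([<that JsonChatStr>]) -> Dataset with one "messages" row
+    #   Prompter.prompt(row) -> row["messages"], the chat sent to the API
+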
+    @staticmethod
+    def parse_logprobs(
+        outputs: Union[Any, List[Any]], tokens: List[List[int]] = None, ctxlen: List[int] = None, **kwargs
+    ) -> List[Tuple[float, bool]]:
+        raise NotImplementedError("Log probability parsing not implemented.")
+
+    @staticmethod
+    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
+        # Pull the parsed "response" column out of the curator output dataset
+        return [output["response"] for output in outputs]
+
+    @property
+    def tokenizer_name(self) -> str:
+        return self.model_name
+
+    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> Union[str, JsonChatStr]:
+        # Defer chat templating to the API side: pass the chat through as JSON
+        return JsonChatStr(json.dumps(chat_history))
+
+    def model_call(self, messages: Union[List[List[int]], List[str], List[JsonChatStr]], **kwargs) -> Optional[dict]:
+        payload = self._create_payload(self.create_message(messages), **kwargs)
+        return self.llm(payload)
+
+    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+        raise NotImplementedError("Log likelihood tokens not implemented for curator.")
+
+    @property
+    def eot_token_id(self) -> Optional[int]:
+        # API-served models expose no local tokenizer, so no EOT token id is available
+        return None
+
+    def generate_until(self, requests: List[Instance], disable_tqdm: bool = False) -> List[str]:
+        if self.tokenized_requests:
+            raise NotImplementedError("Tokenized requests not implemented for curator.")
+
+        # Extract contexts and generation kwargs from the Instance objects
+        contexts = [req.args[0] for req in requests]
+        gen_kwargs = [req.args[1] for req in requests]
+
+        # Curator supports only one set of generation parameters per run
+        assert all(
+            gen_kwargs[0] == gkw for gkw in gen_kwargs
+        ), "Generation parameters must be the same for all requests in curator"
+
+        payload = self._create_payload(self.create_message(contexts), generate=True, gen_kwargs=gen_kwargs[0])
+        response = self.llm(payload)
+        return self.parse_generations(response)
+
+    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> List[float]:
+        raise NotImplementedError("Log likelihood rolling not implemented for curator.")
+
+    def tok_encode(self, string: str, **kwargs) -> List[int]:
+        raise NotImplementedError("Token encoding not implemented for curator.")
diff --git a/eval/eval.py b/eval/eval.py
index ca3deb31..9816134c 100644
--- a/eval/eval.py
+++ b/eval/eval.py
@@ -23,6 +23,7 @@
 import lm_eval.api.task
 import lm_eval.models
 
+from eval.chat_benchmarks.curator_lm import CuratorAPIModel  # noqa: F401  # register curator model
 from eval.task import TaskManager as InstructTaskManager
 from eval.eval_tracker import DCEvaluationTracker
 
diff --git a/pyproject.toml b/pyproject.toml
index 824c1156..294246c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -104,6 +104,7 @@ dependencies = [
     "requests>=2.28",
     "websocket",
     "aiofiles",
+    "bespokelabs-curator",
 
     # Database
     "sqlalchemy",
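
Usage sketch (not part of the patch): with the eval/eval.py import in place,
the model registers under the "curator" key and should be runnable through the
repo's eval entry point. The invocation below follows the existing
--model/--tasks/--model_args convention; the model and task names are
illustrative assumptions, not values from this change:

    python -m eval.eval --model curator --tasks WildBench --model_args "model=gpt-4o-mini"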