From c7d2dc61ea2ba0a9f25299f99ba2d285592b6bf8 Mon Sep 17 00:00:00 2001
From: Jean Mercat <jean.mercat@tri.global>
Date: Thu, 16 Jan 2025 11:52:00 -0800
Subject: [PATCH] wildbench simplify chat input to be compatible with all
 models, add curator model (doesn't work as is)

---
 .../WildBench/eval_instruct.py                |   7 +-
 eval/chat_benchmarks/curator_lm.py            | 164 ++++++++++++++++++
 eval/eval.py                                  |   1 +
 pyproject.toml                                |   1 +
 4 files changed, 172 insertions(+), 1 deletion(-)
 create mode 100644 eval/chat_benchmarks/curator_lm.py

diff --git a/eval/chat_benchmarks/WildBench/eval_instruct.py b/eval/chat_benchmarks/WildBench/eval_instruct.py
index 100e6d68..4d44aa97 100644
--- a/eval/chat_benchmarks/WildBench/eval_instruct.py
+++ b/eval/chat_benchmarks/WildBench/eval_instruct.py
@@ -166,8 +166,13 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
             # Load data
             id_strs, chat_history, extracted_chats, metadata = self.load_dataset()
 
+            # Remove extra fields that might not be compatible with some models
+            simplified_extracted_chats = [
+                [{"role": c["role"], "content": c["content"]} for c in chat] for chat in extracted_chats
+            ]
+
             # Prepare model inputs
-            model_inputs = [model.apply_chat_template(chat) for chat in extracted_chats]
+            model_inputs = [model.apply_chat_template(chat) for chat in simplified_extracted_chats]
 
             # Create temporary directory
             temp_dir_obj = tempfile.TemporaryDirectory()
diff --git a/eval/chat_benchmarks/curator_lm.py b/eval/chat_benchmarks/curator_lm.py
new file mode 100644
index 00000000..e2dca791
--- /dev/null
+++ b/eval/chat_benchmarks/curator_lm.py
@@ -0,0 +1,164 @@
+from typing import List, Dict, Any, Optional, Union, Tuple
+import json
+
+from bespokelabs import curator
+from datasets import Dataset
+
+from lm_eval.api.model import TemplateLM
+from lm_eval.models.api_models import JsonChatStr
+from lm_eval.api.registry import register_model
+from lm_eval.api.instance import Instance
+from lm_eval.models.utils import handle_stop_sequences
+
+
+class prompter(curator.LLM):
+    def prompt(self, row):
+        return row["messages"]
+
+    def parse(self, row, response):
+        return {"response": response}
+
+
+@register_model("curator")
+class CuratorAPIModel(TemplateLM):
+    def __init__(
+        self,
+        model: str = None,
+        pretrained: str = None,
+        max_length: Optional[int] = 2048,
+        num_concurrent: int = 1,
+        max_retries: int = 3,
+        timeout: int = 300,
+        tokenized_requests: bool = False,
+        **kwargs,
+    ):
+        super().__init__()
+        if tokenized_requests:
+            raise NotImplementedError("Tokenized requests not implemented for curator.")
+        self.tokenized_requests = False
+        self.model_name = model or pretrained
+        self.max_length = max_length
+        self.num_concurrent = num_concurrent
+        self.max_retries = max_retries
+        self.timeout = timeout
+        self.llm = None
+        self.gen_kwargs = {}
+        self._max_gen_toks = 2048
+        self.eos = None
+
+    def _create_payload(
+        self,
+        messages: Union[List[List[int]], List[dict], List[str], str],
+        *,
+        generate: bool = False,
+        gen_kwargs: Optional[dict] = None,
+        eos=None,
+        **kwargs,
+    ) -> dict:
+        assert generate, "Curator only supports generation."
+        # Create the payload for the API request
+        if self.llm is None:
+            self.gen_kwargs = gen_kwargs.copy()
+            self.eos = eos
+            max_tokens = gen_kwargs.get("max_gen_toks", self._max_gen_toks)
+            temperature = gen_kwargs.get("temperature", 0)
+            stop = handle_stop_sequences(gen_kwargs.get("until", None), eos)
+            gen_kwargs = {
+                "max_completion_tokens": max_tokens,
+                "temperature": temperature,
+                "stop": stop,
+            }
+            self.llm = prompter(model_name=self.model_name, generation_params=gen_kwargs)
+        else:
+            assert self.gen_kwargs == gen_kwargs, "Generation parameters must be the same for all requests in curator"
+            assert self.eos == eos, "EOS must be the same for all requests in curator"
+        return messages
+
+    def create_message(
+        self, messages: Union[List[List[int]], List[str], List[JsonChatStr]], generate=False
+    ) -> Union[List[List[int]], List[dict], List[str], str]:
+        # Convert messages to the format expected by the API
+        if isinstance(messages, list) and all(isinstance(m, JsonChatStr) for m in messages):
+            return [Dataset.from_dict({"messages": json.loads(m.prompt)}) for m in messages]
+        return messages
+
+    @staticmethod
+    def parse_logprobs(
+        outputs: Union[Any, List[Any]], tokens: List[List[int]] = None, ctxlen: List[int] = None, **kwargs
+    ) -> List[Tuple[float, bool]]:
+        # Implement log probability parsing logic
+        raise NotImplementedError("Log probability parsing not implemented.")
+        logprobs = []
+        for output in outputs:
+            # Assuming output has a structure that includes log probabilities
+            logprob = output.get("logprob", 0.0)  # Replace with actual key
+            is_greedy = output.get("is_greedy", False)  # Replace with actual key
+            logprobs.append((logprob, is_greedy))
+        return logprobs
+
+    @staticmethod
+    def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
+        # Parse the generated outputs from the API
+        return [output["response"] for output in outputs]
+
+    @property
+    def tokenizer_name(self) -> str:
+        return self.model_name
+
+    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> Union[str, JsonChatStr]:
+        # Convert chat history to the required format
+        return JsonChatStr(json.dumps(chat_history))
+
+    def model_call(self, messages: Union[List[List[int]], List[str], List[JsonChatStr]], **kwargs) -> Optional[dict]:
+        payload = self._create_payload(self.create_message(messages), **kwargs)
+        response = self.llm(payload)
+        return response
+
+    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+        raise NotImplementedError("Log likelihood tokens not implemented for curator.")
+        results = []
+        for context, continuation in requests:
+            # Assuming the model can compute log likelihoods
+            response = self.model_call([context, continuation])
+            logprob = response.get("logprob", 0.0)  # Replace with actual key
+            is_greedy = response.get("is_greedy", False)  # Replace with actual key
+            results.append((logprob, is_greedy))
+        return results
+
+    @property
+    def eot_token_id(self) -> Optional[int]:
+        # Assuming the model has a specific end-of-text token ID
+        return self.llm.eot_token_id  # Replace with actual method to get EOT token ID
+
+    def generate_until(self, requests: List[Instance], disable_tqdm: bool = False) -> List[str]:
+        # Tokenize contexts if required
+        if self.tokenized_requests:
+            raise NotImplementedError("Tokenized requests not implemented for curator.")
+
+        # Extract contexts and generation kwargs from the Instance objects
+        contexts = [req.args[0] for req in requests]
+        gen_kwargs = [req.args[1] for req in requests]
+
+        # Assert all gen_kwargs are the same
+        assert all(
+            gen_kwargs[0] == gkw for gkw in gen_kwargs
+        ), "Generation parameters must be the same for all requests in curator"
+
+        payload = self._create_payload(self.create_message(contexts), generate=True, gen_kwargs=gen_kwargs[0])
+        breakpoint()
+        response = self.llm(payload)
+        breakpoint()
+        return response
+
+    def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> List[float]:
+        raise NotImplementedError("Log likelihood rolling not implemented for curator.")
+        loglikelihoods = []
+        for context in requests:
+            response = self.model_call(context)
+            loglikelihood = response.get("loglikelihood", 0.0)  # Replace with actual key
+            loglikelihoods.append(loglikelihood)
+        return loglikelihoods
+
+    def tok_encode(self, string: str, **kwargs) -> List[int]:
+        raise NotImplementedError("Token encoding not implemented for curator.")
+        return self.llm.tokenizer.encode(string)  # Replace with actual method to tokenize
diff --git a/eval/eval.py b/eval/eval.py
index ca3deb31..9816134c 100644
--- a/eval/eval.py
+++ b/eval/eval.py
@@ -23,6 +23,7 @@
 import lm_eval.api.task
 import lm_eval.models
 
+from eval.chat_benchmarks.curator_lm import CuratorAPIModel  # register curator model
 from eval.task import TaskManager as InstructTaskManager
 from eval.eval_tracker import DCEvaluationTracker
 
diff --git a/pyproject.toml b/pyproject.toml
index 824c1156..294246c1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -104,6 +104,7 @@ dependencies = [
     "requests>=2.28",
     "websocket",
     "aiofiles",
+    "bespokelabs-curator",
     
     # Database
     "sqlalchemy",