Merge pull request #71 from jmercat/jean/main_curator
Hot fix in curator generate
neginraoof authored Jan 30, 2025
2 parents f0c35ab + 8ca479c commit 3eb7b44
Showing 1 changed file with 25 additions and 31 deletions.
eval/chat_benchmarks/curator_lm.py (25 additions, 31 deletions)
@@ -31,7 +31,10 @@ def __init__(

         self.model_name = model or pretrained

-        if "gemini" in self.model_name:
+        if "gemini" in self.model_name and "thinking" in self.model_name:
+            max_requests_per_minute = max_requests_per_minute or 200
+            max_tokens_per_minute = max_tokens_per_minute or 400_000
+        elif "gemini" in self.model_name:
             max_requests_per_minute = max_requests_per_minute or 2000
             max_tokens_per_minute = max_tokens_per_minute or 4_000_000
         elif "claude" in self.model_name:
@@ -45,6 +48,10 @@ def __init__(
         self.llm = None
         self.gen_kwargs = {}
         self.eos = None
+        if "temperature" in kwargs:
+            self.gen_kwargs["temperature"] = kwargs["temperature"]
+        if "top_p" in kwargs:
+            self.gen_kwargs["top_p"] = kwargs["top_p"]
         self.backend_params = {
             "invalid_finish_reasons": [
                 "content_filter"
@@ -74,14 +81,21 @@ def _create_payload(
     ) -> dict:
         assert generate, "Curator only supports generation."
         # Create the payload for the API request
-        max_tokens = gen_kwargs.get("max_gen_toks", self.max_length)
-        temperature = gen_kwargs.get("temperature", 0)
+        max_tokens = self.max_length or gen_kwargs.get("max_gen_toks", self.max_length)
+        temperature = self.gen_kwargs.get("temperature", gen_kwargs.get("temperature", 0))
+        top_p = self.gen_kwargs.get("top_p", gen_kwargs.get("top_p", 0.95))
         stop = handle_stop_sequences(gen_kwargs.get("until", None), eos)
         gen_kwargs = {
             "max_completion_tokens": max_tokens,
             "temperature": temperature,
+            "top_p": top_p,
             "stop": stop,
         }
+        if "o1" in self.model_name:
+            print("Warning: O1 model does not support top_p, stop, or temperature. Ignoring them.")
+            gen_kwargs.pop("top_p")
+            gen_kwargs.pop("stop")
+            gen_kwargs.pop("temperature")
         if self.llm is None:
             self.eos = eos
             self.gen_kwargs = gen_kwargs.copy()
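OpenAI's o1-series models reject sampling controls such as `temperature`, `top_p`, and `stop`, so the fix strips them from the payload up front instead of letting the API call fail. A self-contained sketch of the same filtering (the helper name is hypothetical):

    def strip_unsupported_params(model_name: str, payload: dict) -> dict:
        # o1-style reasoning models reject these sampling controls,
        # so drop them rather than let the request error out.
        if "o1" in model_name:
            for key in ("top_p", "stop", "temperature"):
                payload.pop(key, None)
        return payload

    payload = {"max_completion_tokens": 256, "temperature": 0,
               "top_p": 0.95, "stop": ["\n"]}
    print(strip_unsupported_params("o1-mini", payload))
    # {'max_completion_tokens': 256}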
Expand Down Expand Up @@ -156,19 +170,6 @@ def eot_token_id(self) -> Optional[int]:
# Assuming the model has a specific end-of-text token ID
return self.llm.eot_token_id # Replace with actual method to get EOT token ID

def _group_gen_kwargs(self, gen_kwargs: List[dict]) -> List[Tuple[List[int], dict]]:
"""
Group identical generation parameters together and return a list of indices that share the same generation parameters
"""
ordered_set_gen_kwargs = []
for gkw in gen_kwargs:
if gkw not in ordered_set_gen_kwargs:
ordered_set_gen_kwargs.append(gkw)
return [
[i for i, gkw in enumerate(gen_kwargs) if gkw == ordered_set_gen_kwargs[j]]
for j in range(len(ordered_set_gen_kwargs))
], ordered_set_gen_kwargs

def generate_until(self, requests: List[Instance], disable_tqdm: bool = False) -> List[str]:
# Tokenize contexts if required
if self.tokenized_requests:
@@ -178,21 +179,14 @@ def generate_until(self, requests: List[Instance], disable_tqdm: bool = False) -> List[str]:
             contexts = [req.args[0] for req in requests]
             gen_kwargs = [req.args[1] for req in requests]

-        list_indices, gen_kwargs = self._group_gen_kwargs(gen_kwargs)
-
-        # Curator needs a new object for each set of gen_kwargs
-        # Generate responses for each group of generation parameters
-        response_list = []
-        for indices, gkw in zip(list_indices, gen_kwargs):
-            contexts_dataset = self.create_message([contexts[i] for i in indices])
-            payload = self._create_payload(contexts_dataset, generate=True, gen_kwargs=gkw)
-            response_list.append(self.llm(payload)["response"])
-
-        # Re-order responses to match the original request order
-        response = [None] * len(requests)
-        for i, indices in enumerate(list_indices):
-            for idx in indices:
-                response[idx] = response_list[i]
+        # Assert all gen_kwargs are the same
+        assert all(
+            gen_kwargs[0] == gkw for gkw in gen_kwargs
+        ), "Generation parameters must be the same for all requests in curator"
+
+        contexts_dataset = self.create_message(contexts)
+        payload = self._create_payload(contexts_dataset, generate=True, gen_kwargs=gen_kwargs[0])
+        response = self.llm(payload)["response"]
         return response

     def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> List[float]:
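The fix drops the per-group batching (via the removed `_group_gen_kwargs` above) in favor of a single curator call, so it must assert up front that every request carries identical generation parameters; presumably the benchmarks run through this harness always satisfy that. A small illustration of the invariant, using tuples as stand-ins for lm-eval `Instance` objects:

    # Hypothetical (context, gen_kwargs) pairs standing in for Instance.args.
    requests = [
        ("What is 2+2?", {"temperature": 0, "until": ["\n"]}),
        ("Name a prime.", {"temperature": 0, "until": ["\n"]}),
    ]
    gen_kwargs = [gkw for _, gkw in requests]

    # The same invariant the new code asserts: one shared config per batch.
    assert all(gen_kwargs[0] == gkw for gkw in gen_kwargs), \
        "Generation parameters must be the same for all requests in curator"

    # With the invariant held, one payload covers every context.
    contexts = [ctx for ctx, _ in requests]

Because curator needs a new object for each set of generation parameters (per the removed comment), the single-call form also avoids re-instantiating the LLM wrapper for each group.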
