Skip to content

Commit

Permalink
Improving Transform and Rerank Module (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritugala authored Apr 19, 2024
1 parent 04debc9 commit 6709095
Show file tree
Hide file tree
Showing 20 changed files with 920 additions and 790 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
huggingface_data/huggingface_datasets/reranking_dataset_index.json
huggingface_data/huggingface_models/
retrieved_dataset_dict/
result/
checkpoint/
status.yaml

# Outputs generated by the colab demo
trained_model/
trained_tokenizer/
9 changes: 5 additions & 4 deletions examples/create_transform_data_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@

# run this pipeline to retrieve relevant datasets, rerank them,
# and transform them based on the prompt
retriever = DescriptionDatasetRetriever()
num_points_to_transform = 20
total_num_points_to_transform = 20
retriever = DescriptionDatasetRetriever(
auto_transform_data=True,
total_num_points_to_transform=total_num_points_to_transform,
)
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
prompt_spec,
auto_transform_data=True,
num_points_to_transform=num_points_to_transform,
)

# save the final dataset to disk
Expand Down

This file was deleted.

This file was deleted.

Binary file not shown.
297 changes: 233 additions & 64 deletions prompt2model/dataset_retriever/description_dataset_retriever.py

Large diffs are not rendered by default.

205 changes: 78 additions & 127 deletions prompt2model/dataset_retriever/reranking_prompt.py

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions prompt2model/dataset_retriever/task_expansion_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""This module contains the functions to construct the prompt for task expansion."""
METAPROMPT_BASE = "Carefully analyse the task description and examples of the task, and explain the task to give a clearer description. Do not explain each example, but rather capture the general trends. Also place special focus on the format of the input/output examples." # noqa: E501

TASK = """
Task Description: {task_description}
Task Examples: {examples}
"""


def construct_prompt_for_task_explanation(instruction: str, demonstrations: str):
"""Constructs prompt for task explanation.
This is useful for clarifying the requirements of a task,
and providing a clearer description of the task.
Args:
instruction (str): The task instruction.
demonstrations (str): The task demonstrations.
Returns:
str: The constructed prompt.
"""
task = TASK.format(task_description=instruction, examples=demonstrations)
prompt = "\n--------\n".join([METAPROMPT_BASE, task])
return prompt
1 change: 0 additions & 1 deletion prompt2model/dataset_transformer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def transform_data(
self,
prompt_spec: PromptSpec,
dataset: datasets.Dataset,
num_points_to_transform: int,
) -> datasets.Dataset:
"""Transform a split of data.
Expand Down
257 changes: 169 additions & 88 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

import datasets

from prompt2model.dataset_retriever.task_expansion_prompt import (
construct_prompt_for_task_explanation,
)
from prompt2model.dataset_transformer.base import DatasetTransformer
from prompt2model.dataset_transformer.prompt_template import (
construct_prompt_for_plan,
Expand All @@ -31,122 +34,200 @@ class PromptBasedDatasetTransformer(DatasetTransformer):

def __init__(
self,
num_points_to_transform: int = 10,
max_allowed_failed_transforms: int = 3,
plan_prompt_fn: Callable[
[str, str, list[dict], int], str
[str, str, list[dict]], str
] = construct_prompt_for_plan,
transform_prompt_fn: Callable[
[str, str, str, dict], str
[str, str, str, str], str
] = construct_prompt_for_transform_data,
num_retries: int = 10,
):
"""Initialize the class.
"""Initializes an instance of the PromptBasedDatasetTransformer class.
Args:
plan_prompt_fn: A function that takes in a description of the target task,
example of the target task,
list of dictionaries where each dictionary is a row from a potentially
relevant dataset,
and the number of rows to use from this potentially relevant dataset,
and returns a plan prompt.
transform_prompt_fn: A function that takes in a description of the target
task, an example of the target task,
plan for dataset transformation,
and the row from a potentially relevant dataset to be transformed.
num_points_to_transform: The number of points to transform.
max_allowed_failed_transforms: The maximum number of
failed transforms allowed.
plan_prompt_fn: The function to construct the prompt for plan
transform_prompt_fn: The function to construct the prompt
for transform data.
num_retries: The number of retries to attempt for each API call.
"""
self.plan_prompt_fn = plan_prompt_fn
self.transform_prompt_fn = transform_prompt_fn
self.plan: str = ""

def make_dataset_from_samples(
self,
inputs: list[str],
outputs: list[str],
) -> datasets.DatasetDict:
"""Given a list of inputs and outputs, make a dataset.
This function takes in inputs and outputs, both as list of strings,
and returns a DatasetDict object with a single split, "train". It has
two columns, "input_col" and "output_col".
Args:
inputs: A list of inputs, each input is a string.
outputs: A list of outputs, each output is a string.
Returns:
A DatasetDict object with a single split, "train". It has two
columns, "input_col" and "output_col".
"""
if len(inputs) <= 0 or len(inputs) != len(outputs):
raise ValueError("Length of inputs and outputs must be >0 and equal.")

dataset_dict = {}
dataset_dict["train"] = datasets.Dataset.from_dict(
{"input_col": inputs, "output_col": outputs}
self.num_points_to_transform = num_points_to_transform
self.curr_failed_transforms = 0
self.max_allowed_failed_transforms = max_allowed_failed_transforms
self.num_retries = num_retries

def generate_task_explanation(self, prompt_spec: PromptSpec) -> str:
"""Generate task explanation."""
task_explanation_prompt = construct_prompt_for_task_explanation(
prompt_spec.instruction, prompt_spec.examples
)
return make_single_api_request(
task_explanation_prompt, max_api_calls=self.num_retries
)
return datasets.DatasetDict(dataset_dict)

def transform_data(
self,
prompt_spec: PromptSpec,
dataset: datasets.Dataset,
num_points_to_transform: int,
) -> datasets.DatasetDict:
"""Transform the dataset according to the prompt_spec and dataset."""
def generate_plan(
self, task_explanation: str, dataset: datasets.Dataset, prompt_spec: PromptSpec
) -> str:
"""Generate plan for the task."""
plan_prompt = self.plan_prompt_fn(
prompt_spec.instruction,
prompt_spec.examples,
dataset,
min(5, len(dataset)),
task_explanation, prompt_spec.examples, dataset
)
self.plan = make_single_api_request(plan_prompt)

logger.info(f"Plan created. Plan: {self.plan}")

inputs = []
outputs = []
return make_single_api_request(plan_prompt, max_api_calls=self.num_retries)

max_len = min(num_points_to_transform, len(dataset))
len_count = 0
def generate_transform_prompts(
self,
task_explanation: str,
dataset: datasets.Dataset,
prompt_spec: PromptSpec,
) -> list[str]:
"""Get transform prompts for each row in the dataset."""
transform_prompts = []
for row in dataset:
for i in range(min(self.num_points_to_transform, len(dataset))):
row = dataset[i]
transform_prompt = self.transform_prompt_fn(
prompt_spec.instruction,
prompt_spec.examples,
self.plan,
row,
task_explanation, row, self.plan, prompt_spec.examples
)
transform_prompts.append(transform_prompt)
return transform_prompts

len_count += 1
if len_count >= max_len:
break
def generate_responses(
self, transform_prompts_batch: list[str], model_name="gpt-3.5-turbo"
) -> list[str]:
"""Generate responses for the given transform prompts.
async def generate_responses(transform_prompts):
responses = await api_tools.default_api_agent.generate_batch_completion(
transform_prompts,
temperature=0,
responses_per_request=1,
requests_per_minute=15,
)
return responses
Args:
transform_prompts_batch: A list of transform prompts.
model_name: The name of the model to use. Defaults to
"gpt-3.5-turbo" to save costs.
try:
loop = asyncio.get_event_loop()
responses = loop.run_until_complete(generate_responses(transform_prompts))
except API_ERRORS as e:
handle_api_error(e)
Returns:
A list of generated responses.
Raises:
API_ERRORS: If there is an error with the API.
"""
api_call_counter = 0
last_error = None
responses = []
while True:
api_call_counter += 1

async def generate_responses_async(transform_prompts):
"""Generate responses asynchronously using the specified model."""
responses = await api_tools.APIAgent(
model_name=model_name
).generate_batch_completion(
transform_prompts,
temperature=0,
responses_per_request=1,
requests_per_minute=15,
)
return responses

try:
loop = asyncio.get_event_loop()
responses = loop.run_until_complete(
generate_responses_async(transform_prompts_batch)
)
break
except API_ERRORS as e:
last_error = e
handle_api_error(e)
if api_call_counter > self.num_retries:
# In case we reach maximum number of API calls, we raise an error.
logger.error("Maximum number of API calls reached.")
raise RuntimeError(
"Maximum number of API calls reached."
) from last_error

return responses

def process_responses(
self, responses: list, prompt_spec: PromptSpec
) -> tuple[list[str], list[str]]:
"""Process the responses received from the API.
Args:
responses: A list of response strings from the API.
prompt_spec: The PromptSpec object containing the instruction and examples.
Returns:
A tuple containing two lists: inputs and outputs.
- inputs: A list of transformed input strings.
- outputs: A list of transformed output strings.
"""
inputs, outputs = [], []
show_sample_flag = False
for response in responses:
try:
extraction = find_and_parse_json(response, ["input", "output"], [])
if extraction is not None:
inputs.append(str(extraction["input"]))
outputs.append(str(extraction["output"]))
if extraction["input"] is None or extraction["output"] is None:
raise ValueError("Input or output is None")
input = str(extraction["input"]).strip()
output = str(extraction["output"]).strip()
if input in prompt_spec.examples:
raise ValueError("Repeated Task Examples from prompt")

inputs.append(input)
outputs.append(output)
if show_sample_flag:
logger.info(f"inputs\n{input}\n\nouputs\n{output}")
show_sample_flag = False

except Exception as e:
logger.error(f"Error extracting from response: {response}\nError: {e}")
continue
logger.error(f"Error extracting from response: {e}")
self.curr_failed_transforms += 1
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
break

logger.info(f"Requested length: {max_len}\nActual length: {len(inputs)}\n")
return inputs, outputs

return self.make_dataset_from_samples(inputs, outputs)
def transform_data(
self, prompt_spec: PromptSpec, dataset: datasets.Dataset
) -> tuple[list[str], list[str]]:
"""Transforms the given dataset based on the provided prompt specification.
Args:
prompt_spec: The prompt specification object that defines
the transformation rules.
dataset: The dataset to be transformed.
Returns:
A tuple containing two lists: inputs and outputs.
"""
task_explanation = self.generate_task_explanation(prompt_spec)
self.plan = self.generate_plan(task_explanation, dataset, prompt_spec)
logger.info(f"Plan created. Plan: {self.plan}")

transform_prompts = self.generate_transform_prompts(
task_explanation, dataset, prompt_spec
)
inputs, outputs = [], []
for batch_indices in range(0, len(transform_prompts), 100):
transform_prompt_batch = transform_prompts[
batch_indices : batch_indices + 100
]
responses = self.generate_responses(transform_prompt_batch)
curr_inputs, curr_outputs = self.process_responses(responses, prompt_spec)
inputs.extend(curr_inputs)
outputs.extend(curr_outputs)
if self.curr_failed_transforms > self.max_allowed_failed_transforms:
logger.error(
f"Exceeded max allowed failed transforms: {self.curr_failed_transforms}" # noqa: E501
)
self.max_allowed_failed_transforms = 0
break

logger.info(
f"Requested length: {self.num_points_to_transform}\nActual length: {len(inputs)}\n" # noqa: E501
)
return inputs, outputs
Loading

0 comments on commit 6709095

Please sign in to comment.