Skip to content

Commit

Permalink
format: ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
Roni-Friedman committed Oct 28, 2024
1 parent c22397b commit 64aff53
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 37 deletions.
4 changes: 3 additions & 1 deletion src/instructlab/eval/mmlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def run(self, server_url: str | None = None) -> tuple:

return overall_score, individual_scores

def _run_mmlu(self, server_url: str | None = None, return_all_results:bool = False) -> dict:
def _run_mmlu(
self, server_url: str | None = None, return_all_results: bool = False
) -> dict:
if server_url is not None:
# Requires lm_eval >= 0.4.4
model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"
Expand Down
71 changes: 39 additions & 32 deletions src/instructlab/eval/unitxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
"""

# Standard
import os, shutil
import yaml
from uuid import uuid4
import os
import shutil

# Third Party
from lm_eval.tasks.unitxt import task
import yaml

# First Party
from instructlab.eval.mmlu import MMLUBranchEvaluator
Expand All @@ -20,7 +21,8 @@

logger = setup_logger(__name__)

TEMP_DIR_PREFIX = 'unitxt_temp'
TEMP_DIR_PREFIX = "unitxt_temp"


class UnitxtEvaluator(MMLUBranchEvaluator):
"""
Expand All @@ -29,45 +31,50 @@ class UnitxtEvaluator(MMLUBranchEvaluator):
Attributes:
model_path absolute path to or name of a huggingface model
unitxt_recipe unitxt recipe (see unitxt.ai for more information)
A Recipe holds a complete specification of a unitxt pipeline
A Recipe holds a complete specification of a unitxt pipeline
Example: card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10
"""

name = "unitxt"

def __init__(
self,
model_path,
model_path,
unitxt_recipe: str,
):
task = self.assign_task_name()

Check warning on line 46 in src/instructlab/eval/unitxt.py

View workflow job for this annotation

GitHub Actions / pylint

W0621: Redefining name 'task' from outer scope (line 13) (redefined-outer-name)
tasks_dir = self.assign_tasks_dir(task)
super().__init__(
model_path = model_path,
tasks_dir = tasks_dir,
tasks = [task],
few_shots = 0
model_path=model_path, tasks_dir=tasks_dir, tasks=[task], few_shots=0
)
self.unitxt_recipe = unitxt_recipe

def assign_tasks_dir(self, task):

Check warning on line 53 in src/instructlab/eval/unitxt.py

View workflow job for this annotation

GitHub Actions / pylint

W0621: Redefining name 'task' from outer scope (line 13) (redefined-outer-name)
return f'{TEMP_DIR_PREFIX}_{task}'
return f"{TEMP_DIR_PREFIX}_{task}"

def assign_task_name(self) -> str:
    """Return a fresh task name: the string form of a random UUID (uuid4)."""
    return str(uuid4())

def prepare_unitxt_files(self)->tuple:
def prepare_unitxt_files(self) -> tuple:
task = self.tasks[0]

Check warning on line 60 in src/instructlab/eval/unitxt.py

View workflow job for this annotation

GitHub Actions / pylint

W0621: Redefining name 'task' from outer scope (line 13) (redefined-outer-name)
yaml_file = os.path.join(self.tasks_dir,f"{task}.yaml")
yaml_file = os.path.join(self.tasks_dir, f"{task}.yaml")
create_unitxt_pointer(self.tasks_dir)
create_unitxt_yaml(yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task)
create_unitxt_yaml(
yaml_file=yaml_file, unitxt_recipe=self.unitxt_recipe, task_name=task
)

def remove_unitxt_files(self):
if self.tasks_dir.startswith(TEMP_DIR_PREFIX): #to avoid unintended deletion if this class is inherited
if self.tasks_dir.startswith(
TEMP_DIR_PREFIX
): # to avoid unintended deletion if this class is inherited
shutil.rmtree(self.tasks_dir)
else:
logger.warning(f"unitxt tasks dir did not start with '{TEMP_DIR_PREFIX}' and therefor was not deleted")
logger.warning(

Check warning on line 73 in src/instructlab/eval/unitxt.py

View workflow job for this annotation

GitHub Actions / pylint

W1203: Use lazy % or .format() formatting in logging functions (logging-fstring-interpolation)
f"unitxt tasks dir did not start with '{TEMP_DIR_PREFIX}' and therefor was not deleted"
)

def run(self,server_url: str | None = None) -> tuple:
def run(self, server_url: str | None = None) -> tuple:
"""
Runs evaluation
Expand All @@ -80,13 +87,16 @@ def run(self,server_url: str | None = None) -> tuple:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
results = self._run_mmlu(server_url=server_url, return_all_results=True)
taskname = self.tasks[0]
global_scores = results['results'][taskname]
global_scores.pop('alias')
global_scores = results["results"][taskname]
global_scores.pop("alias")
try:
instances = results['samples'][taskname]
instances = results["samples"][taskname]
instance_scores = {}
metrics = [metric.replace('metrics.','') for metric in instances[0]['doc']['metrics']]
for i,instance in enumerate(instances):
metrics = [
metric.replace("metrics.", "")
for metric in instances[0]["doc"]["metrics"]
]
for i, instance in enumerate(instances):
scores = {}
for metric in metrics:
scores[metric] = instance[metric][0]
Expand All @@ -97,23 +107,20 @@ def run(self,server_url: str | None = None) -> tuple:
logger.error(e.__traceback__)
instance_scores = None
self.remove_unitxt_files()
return global_scores,instance_scores
return global_scores, instance_scores


def create_unitxt_yaml(yaml_file,unitxt_recipe, task_name):
data = {
'task': f'{task_name}',
'include': 'unitxt',
'recipe': f'{unitxt_recipe}'
}
with open(yaml_file, 'w') as file:
def create_unitxt_yaml(yaml_file, unitxt_recipe, task_name):
    """Write a task YAML file containing the task name and its unitxt recipe.

    The file holds three keys: ``task`` (the task name), ``include``
    (the literal string ``"unitxt"``) and ``recipe`` (the recipe string).

    Args:
        yaml_file: path of the YAML file to write; opened in "w" mode, so an
            existing file is overwritten.
        unitxt_recipe: recipe value stored under the ``recipe`` key.
        task_name: task identifier stored under the ``task`` key.
    """
    data = {
        "task": str(task_name),
        "include": "unitxt",
        "recipe": str(unitxt_recipe),
    }
    # Explicit encoding so the written file does not depend on the locale.
    with open(yaml_file, "w", encoding="utf-8") as file:
        yaml.dump(data, file, default_flow_style=False)
    # Log the task_name parameter: the bare name `task` in this scope is the
    # module imported from lm_eval.tasks.unitxt, not the task being written.
    logger.debug("task %s unitxt recipe written to %s", task_name, yaml_file)


def create_unitxt_pointer(tasks_dir):
class_line = "class: !function " + task.__file__.replace("task.py", "task.Unitxt")
output_file = os.path.join(tasks_dir,'unitxt')
output_file = os.path.join(tasks_dir, "unitxt")
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'w') as f:
with open(output_file, "w") as f:

Check warning on line 124 in src/instructlab/eval/unitxt.py

View workflow job for this annotation

GitHub Actions / pylint

W1514: Using open without explicitly specifying an encoding (unspecified-encoding)
f.write(class_line)
logger.debug(f"Unitxt task pointer written to {output_file}")

Check warning on line 126 in src/instructlab/eval/unitxt.py

View workflow job for this annotation

GitHub Actions / pylint

W1203: Use lazy % or .format() formatting in logging functions (logging-fstring-interpolation)
6 changes: 2 additions & 4 deletions tests/test_unitxt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ def test_unitxt():
try:
model_path = "instructlab/granite-7b-lab"
unitxt_recipe = "card=cards.wnli,template=templates.classification.multi_class.relation.default,max_train_instances=5,loader_limit=20,num_demos=3,demos_pool_size=10"
unitxt = UnitxtEvaluator(
model_path=model_path, unitxt_recipe=unitxt_recipe
)
unitxt = UnitxtEvaluator(model_path=model_path, unitxt_recipe=unitxt_recipe)
overall_score, single_scores = unitxt.run()
print(overall_score)
except Exception as exc:
Expand All @@ -19,4 +17,4 @@ def test_unitxt():


if __name__ == "__main__":
assert test_unitxt() == True
assert test_unitxt() == True

0 comments on commit 64aff53

Please sign in to comment.