From 6a2f0b4881afa6c0a97e347335e3b118b12b717a Mon Sep 17 00:00:00 2001
From: mpc
Date: Fri, 25 Oct 2024 11:13:50 +0100
Subject: [PATCH] re-added ruff

---
 notebooks/ragas_eval.ipynb       | 20 ++++++++++----------
 notebooks/ragas_synth.ipynb      | 10 +++++-----
 pyproject.toml                   | 27 ++++++++++++++++++++++-----
 scripts/chunk_data.py            |  6 ++++--
 scripts/create_embeddings.py     |  8 +++++---
 scripts/evaluate.py              | 23 ++++++++++++-----------
 scripts/extract_metadata.py      |  5 ++---
 scripts/fetch_eidc_metadata.py   |  3 ++-
 scripts/fetch_supporting_docs.py | 21 +++++++++++++--------
 scripts/run_rag_pipeline.py      | 21 ++++++++++++---------
 scripts/upload_to_docstore.py    | 13 ++++++++-----
 11 files changed, 95 insertions(+), 62 deletions(-)

diff --git a/notebooks/ragas_eval.ipynb b/notebooks/ragas_eval.ipynb
index 56bcb43..395269f 100644
--- a/notebooks/ragas_eval.ipynb
+++ b/notebooks/ragas_eval.ipynb
@@ -37,15 +37,15 @@
    }
   ],
   "source": [
+    "import nest_asyncio\n",
     "import pandas as pd\n",
-    "from datasets import Dataset\n",
-    "from ragas import evaluate\n",
-    "from ragas.run_config import RunConfig\n",
-    "from langchain_community.embeddings import OllamaEmbeddings\n",
-    "from langchain_community.chat_models import ChatOllama\n",
     "import plotly.graph_objects as go\n",
     "import plotly.io as pio\n",
-    "import nest_asyncio"
+    "from datasets import Dataset\n",
+    "from langchain_community.chat_models import ChatOllama\n",
+    "from langchain_community.embeddings import OllamaEmbeddings\n",
+    "from ragas import evaluate\n",
+    "from ragas.run_config import RunConfig"
   ]
  },
 {
@@ -248,13 +248,13 @@
   "outputs": [],
   "source": [
    "from ragas.metrics import (\n",
-   "    faithfulness,\n",
+   "    answer_correctness,\n",
    "    answer_relevancy,\n",
+   "    answer_similarity,\n",
+   "    context_entity_recall,\n",
    "    context_precision,\n",
    "    context_recall,\n",
-   "    context_entity_recall,\n",
-   "    answer_similarity,\n",
-   "    answer_correctness,\n",
+   "    faithfulness,\n",
    ")"
   ]
  },
diff --git a/notebooks/ragas_synth.ipynb b/notebooks/ragas_synth.ipynb
index cf39b04..f8057b0 100644
--- a/notebooks/ragas_synth.ipynb
+++ b/notebooks/ragas_synth.ipynb
@@ -14,12 +14,12 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "from langchain_community.embeddings import OllamaEmbeddings\n",
+    "import nest_asyncio\n",
     "from langchain_community.chat_models import ChatOllama\n",
-    "from ragas.testset.generator import TestsetGenerator\n",
-    "from ragas.testset.evolutions import simple, reasoning, multi_context\n",
+    "from langchain_community.embeddings import OllamaEmbeddings\n",
     "from ragas.run_config import RunConfig\n",
-    "import nest_asyncio"
+    "from ragas.testset.evolutions import multi_context, reasoning, simple\n",
+    "from ragas.testset.generator import TestsetGenerator"
   ]
  },
 {
@@ -65,7 +65,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "docs = [] # load a set of langchain documents to base the synthetic test set generation on"
+    "docs = [] # load a set of langchain docs to base the synthetic test set generation on"
   ]
  },
 {
diff --git a/pyproject.toml b/pyproject.toml
index 7ecdec5..bb37d7f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,6 @@
 [project]
 name = "llm-eval"
-dynamic = ["version"]
-
+version = "0.1.0"
 dependencies = [
     "plotly == 5.24.1",
     "pandas == 2.2.3",
@@ -18,7 +17,6 @@ dependencies = [
     "ragas == 0.1.10",
     "nltk == 3.9.1",
     "nbformat == 4.2.0",
-    "ruff == 0.7.0",
 ]
 
 [project.optional-dependencies]
@@ -26,6 +24,25 @@ jupyter = [
     "ipykernel",
     "ipywidgets",
 ]
+lint = [
+    "ruff == 0.7.1",
+    "mypy == 1.13.0",
+]
+dev = [
+    "llm-eval[jupyter,lint]"
+]
+
+[tool.ruff]
+line-length = 88
+
+[tool.ruff.lint]
+select = [
+    "I",
+    "E",
+    "F",
+    "ANN"
+]
+fixable = ["ALL"]
 
-[tool.setuptools]
-py-modules = []
+[tool.ruff.lint.pydocstyle]
+convention = "google"
diff --git a/scripts/chunk_data.py b/scripts/chunk_data.py
index ace111d..28707ed 100644
--- a/scripts/chunk_data.py
+++ b/scripts/chunk_data.py
@@ -1,6 +1,6 @@
-from typing import List, Dict
 import json
 from argparse import ArgumentParser
+from typing import Any, Dict, List
 
 
 def chunk_value(value: str, chunk_size: int, overlap: int) -> List[str]:
@@ -12,7 +12,9 @@ def chunk_value(value: str, chunk_size: int, overlap: int) -> List[str]:
     return chunks
 
 
-def chunk_metadata_value(metada_value, chunk_size, overlap):
+def chunk_metadata_value(
+    metada_value: Dict[str, Any], chunk_size: int, overlap: int
+) -> List[Dict[str, Any]]:
     chunks = chunk_value(metada_value["value"], chunk_size, overlap)
     return [
         {
diff --git a/scripts/create_embeddings.py b/scripts/create_embeddings.py
index 2ad9cc9..7aa507c 100644
--- a/scripts/create_embeddings.py
+++ b/scripts/create_embeddings.py
@@ -1,15 +1,17 @@
 import json
-from sentence_transformers import SentenceTransformer
 from argparse import ArgumentParser
+
+from sentence_transformers import SentenceTransformer
+from torch import Tensor
 from tqdm import tqdm
 
 
-def create_embedding(text):
+def create_embedding(text: str) -> Tensor:
     model = SentenceTransformer("all-MiniLM-L6-v2")
     return model.encode(text)
 
 
-def main(input_file, output_file):
+def main(input_file: str, output_file: str) -> None:
     with open(input_file) as input, open(output_file, "w") as output:
         data = json.load(input)
         for chunk in tqdm(data):
diff --git a/scripts/evaluate.py b/scripts/evaluate.py
index d7ac98f..c130e96 100644
--- a/scripts/evaluate.py
+++ b/scripts/evaluate.py
@@ -1,23 +1,24 @@
+import json
 from argparse import ArgumentParser
+
+import nest_asyncio
 import pandas as pd
-from datasets import Dataset
-from ragas import evaluate
-from ragas.run_config import RunConfig
-from langchain_community.embeddings import OllamaEmbeddings
-from langchain_community.chat_models import ChatOllama
 import plotly.graph_objects as go
 import plotly.io as pio
-import nest_asyncio
+from datasets import Dataset
+from langchain_community.chat_models import ChatOllama
+from langchain_community.embeddings import OllamaEmbeddings
+from ragas import evaluate
 from ragas.metrics import (
-    faithfulness,
+    answer_correctness,
     answer_relevancy,
+    answer_similarity,
+    context_entity_recall,
     context_precision,
     context_recall,
-    context_entity_recall,
-    answer_similarity,
-    answer_correctness,
+    faithfulness,
 )
-import json
+from ragas.run_config import RunConfig
 
 
 def main(eval_dataset: str, metric_output: str, image_output: str) -> None:
diff --git a/scripts/extract_metadata.py b/scripts/extract_metadata.py
index 8007d09..9bd4c3c 100644
--- a/scripts/extract_metadata.py
+++ b/scripts/extract_metadata.py
@@ -1,7 +1,6 @@
-from typing import List, Dict
 import json
 from argparse import ArgumentParser
-
+from typing import Dict, List
 
 METADATA_FIELDS = ["title", "description", "lineage"]
 
@@ -30,7 +29,7 @@ def parse_eidc_metadata(file_path: str) -> List[Dict[str, str]]:
     return data
 
 
-def main(input, output) -> None:
+def main(input: str, output: str) -> None:
     data = parse_eidc_metadata(input)
     with open(output, "w") as f:
         json.dump(data, f, indent=4)
diff --git a/scripts/fetch_eidc_metadata.py b/scripts/fetch_eidc_metadata.py
index f411c16..0ab6297 100644
--- a/scripts/fetch_eidc_metadata.py
+++ b/scripts/fetch_eidc_metadata.py
@@ -1,7 +1,8 @@
-import requests
 import json
 from argparse import ArgumentParser
 
+import requests
+
 URL = "https://catalogue.ceh.ac.uk/eidc/documents"
 
 
diff --git a/scripts/fetch_supporting_docs.py b/scripts/fetch_supporting_docs.py
index 66e77ac..d95493b 100644
--- a/scripts/fetch_supporting_docs.py
+++ b/scripts/fetch_supporting_docs.py
@@ -1,15 +1,17 @@
-from argparse import ArgumentParser
-import logging
 import json
-from tqdm import tqdm
-import requests
+import logging
 import os
+from argparse import ArgumentParser
 from typing import Dict, List
+
+import requests
 from dotenv import load_dotenv
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
-def extract_ids(metadata_file: str):
+
+def extract_ids(metadata_file: str) -> List[str]:
     with open(metadata_file) as f:
         json_data = json.load(f)
         ids = [dataset["identifier"] for dataset in json_data["results"]]
@@ -19,7 +21,8 @@
 def get_supporting_docs(eidc_id: str, user: str, password: str) -> List[Dict[str, str]]:
     try:
         res = requests.get(
-            f"https://legilo.eds-infra.ceh.ac.uk/{eidc_id}/documents", auth=(user, password)
+            f"https://legilo.eds-infra.ceh.ac.uk/{eidc_id}/documents",
+            auth=(user, password),
         )
         json_data = res.json()
         docs = []
@@ -27,11 +30,13 @@ def get_supporting_docs(eidc_id: str, user: str, password: str) -> List[Dict[str,
             docs.append({"id": eidc_id, "field": key, "value": val})
         return docs
     except Exception as e:
-        logger.error(f"Failed to download supporting docs for dataset {eidc_id}", exc_info=e)
+        logger.error(
+            f"Failed to download supporting docs for dataset {eidc_id}", exc_info=e
+        )
         return []
 
 
-def main(metadata_file: str, supporting_docs_file: str):
+def main(metadata_file: str, supporting_docs_file: str) -> None:
     load_dotenv()
     user = os.getenv("username")
     password = os.getenv("password")
diff --git a/scripts/run_rag_pipeline.py b/scripts/run_rag_pipeline.py
index 91408ea..2c620e5 100644
--- a/scripts/run_rag_pipeline.py
+++ b/scripts/run_rag_pipeline.py
@@ -1,13 +1,14 @@
-from argparse import ArgumentParser
 import shutil
+from argparse import ArgumentParser
+from typing import Any, Dict, List, Tuple
+
+import pandas as pd
 from haystack import Pipeline
-from haystack_integrations.document_stores.chroma import ChromaDocumentStore
-from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever
 from haystack.components.builders import PromptBuilder
-from haystack_integrations.components.generators.ollama.generator import OllamaGenerator
 from haystack.components.builders.answer_builder import AnswerBuilder
-import pandas as pd
-
+from haystack_integrations.components.generators.ollama.generator import OllamaGenerator
+from haystack_integrations.components.retrievers.chroma import ChromaQueryTextRetriever
+from haystack_integrations.document_stores.chroma import ChromaDocumentStore
 
 TMP_DOC_PATH = ".tmp/doc-store"
 
@@ -61,7 +62,7 @@ def build_rag_pipeline(model_name: str, collection_name: str) -> Pipeline:
     return rag_pipe
 
 
-def run_query(query: str, pipeline: Pipeline):
+def run_query(query: str, pipeline: Pipeline) -> Dict[str, Any]:
     return pipeline.run(
         {
             "retriever": {"query": query},
@@ -71,7 +72,9 @@
 )
 
 
-def query_pipeline(questions, rag_pipe):
+def query_pipeline(
+    questions: List[str], rag_pipe: Pipeline
+) -> Tuple[List[str], List[List[str]]]:
     answers = []
     contexts = []
     for q in questions:
@@ -85,7 +88,7 @@
 def main(
     test_data_file: str, ouput_file: str, doc_store_path: str, collection_name: str
-):
+) -> None:
     shutil.copytree(doc_store_path, TMP_DOC_PATH)
 
     rag_pipe = build_rag_pipeline("llama3.1", collection_name)
diff --git a/scripts/upload_to_docstore.py b/scripts/upload_to_docstore.py
index 7b547d7..9f1a880 100644
--- a/scripts/upload_to_docstore.py
+++ b/scripts/upload_to_docstore.py
@@ -1,14 +1,16 @@
-from argparse import ArgumentParser
 import json
-import uuid
-import shutil
 import os
+import shutil
+import uuid
+from argparse import ArgumentParser
 
 import chromadb
 from chromadb.utils import embedding_functions
 
 
-def main(input_file: str, output_path: str, collection_name: str, embedding_model: str):
+def main(
+    input_file: str, output_path: str, collection_name: str, embedding_model: str
+) -> None:
     if os.path.exists(output_path):
         shutil.rmtree(output_path)
 
@@ -55,7 +57,8 @@ def main(input_file: str, output_path: str, collection_name: str, embedding_mode
     parser.add_argument(
         "-em",
         "--embedding_model",
-        help="Embedding model to use in the doc store (must be the same as the function used to create embeddings.)",
+        help="Embedding model to use in the doc store (must be the same as "
+        "the function used to create embeddings).",
         default="all-MiniLM-v2",
     )
     args = parser.parse_args()
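--
Usage sketch (not part of the commit): with the lint extras and ruff
configuration introduced above, the checks can be run locally roughly as
follows. This assumes the repository root and an active virtualenv; the
extras group name comes from the pyproject.toml changes in this patch.

    pip install -e ".[lint]"   # installs ruff 0.7.1 and mypy 1.13.0
    ruff check --fix .         # applies the selected I/E/F/ANN rules; fixable = ["ALL"] permits autofixes
    mypy scripts/              # exercises the type annotations added in this patch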