From 67f1bb430ed9b9a037974eb52c55f31dfcef3b05 Mon Sep 17 00:00:00 2001 From: Aakriti Kinra <52823721+AakritiKinra@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:12:36 -0500 Subject: [PATCH] ruff changes --- llments/eval/factscore/abstain_detection.py | 22 +++---- llments/eval/factscore/atomic_facts.py | 60 +++++++------------- llments/eval/factscore/clm.py | 16 ++---- llments/eval/factscore/download_data.py | 14 ++--- llments/eval/factscore/factscorer.py | 25 +++----- llments/eval/factscore/lm.py | 28 +++------ llments/eval/factscore/npm.py | 38 ++++--------- llments/eval/factscore/openai_lm.py | 24 +++----- llments/eval/factscore/retrieval.py | 63 ++++++--------------- llments/eval/factscore/utils.py | 32 ++++------- 10 files changed, 101 insertions(+), 221 deletions(-) diff --git a/llments/eval/factscore/abstain_detection.py b/llments/eval/factscore/abstain_detection.py index d1291dd..b9cb8ac 100644 --- a/llments/eval/factscore/abstain_detection.py +++ b/llments/eval/factscore/abstain_detection.py @@ -1,6 +1,4 @@ -""" -Abstain Detection Module -""" +"""Abstain Detection Module.""" import numpy as np import re from typing import List @@ -20,8 +18,7 @@ ] def remove_citation(text: str) -> str: - """ - Remove citation references and fix specific starting phrases in the text. + """Remove citation references and fix specific starting phrases in the text. Args: text (str): The input text from which citations are to be removed. @@ -35,8 +32,7 @@ def remove_citation(text: str) -> str: return text def is_invalid_ppl(text: str) -> bool: - """ - Check if the text starts with any invalid phrases indicating insufficient information. + """Check if the text starts with any invalid phrases indicating insufficient information. Args: text (str): The input text to be checked. @@ -47,8 +43,7 @@ def is_invalid_ppl(text: str) -> bool: return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions]) def is_invalid_paragraph_ppl(text: str) -> bool: - """ - Determine if a paragraph is invalid based on its content. + """Determine if a paragraph is invalid based on its content. A paragraph is considered invalid if it is empty or contains any invalid phrases. @@ -61,8 +56,7 @@ def is_invalid_paragraph_ppl(text: str) -> bool: return len(text.strip())==0 or np.any([mention.lower() in text.lower() for mention in invalid_ppl_mentions]) def perplexity_ai_abstain_detect(generation: str) -> bool: - """ - Detect if the AI generation should abstain based on perplexity analysis. + """Detect if the AI generation should abstain based on perplexity analysis. This function removes citations from the generation, checks if it starts with any invalid phrases, and verifies that all paragraphs contain valid information. @@ -88,8 +82,7 @@ def perplexity_ai_abstain_detect(generation: str) -> bool: return False def generic_abstain_detect(generation: str) -> bool: - """ - Detect if the generation should abstain based on generic abstain phrases. + """Detect if the generation should abstain based on generic abstain phrases. Args: generation (str): The generated text to be analyzed. @@ -100,8 +93,7 @@ def generic_abstain_detect(generation: str) -> bool: return generation.startswith("I'm sorry") or "provide more" in generation def is_response_abstained(generation: str, fn_type: str) -> bool: - """ - Determine if the response should be abstained based on the specified detection function type. + """Determine if the response should be abstained based on the specified detection function type. 
Args: generation (str): The generated text to be analyzed. diff --git a/llments/eval/factscore/atomic_facts.py b/llments/eval/factscore/atomic_facts.py index f0e2905..0f88e1e 100644 --- a/llments/eval/factscore/atomic_facts.py +++ b/llments/eval/factscore/atomic_facts.py @@ -1,6 +1,4 @@ -""" -Atomic Facts Module -""" +"""Atomic Facts Module.""" import json import numpy as np import re @@ -17,8 +15,7 @@ nltk.download("punkt") class AtomicFactGenerator: - """ - A generator class to convert AI-generated text into atomic facts. + """A generator class to convert AI-generated text into atomic facts. Attributes: nlp (spacy.lang.en.English): The spaCy language model for NLP tasks. @@ -34,8 +31,7 @@ def __init__( demon_dir: str = "/factscore_data", gpt3_cache_file: Optional[str] = None, ) -> None: - """ - Initialize the AtomicFactGenerator. + """Initialize the AtomicFactGenerator. Args: key_path (str, optional): Path to the OpenAI API key file. Defaults to "key.txt". @@ -56,16 +52,13 @@ def __init__( self.bm25 = BM25Okapi(tokenized_corpus) def save_cache(self) -> None: - """ - Save the OpenAI language model cache. - """ + """Save the OpenAI language model cache.""" self.openai_lm.save_cache() def run( self, generation: str, cost_estimate: Optional[Any] = None ) -> Tuple[List[Tuple[str, List[str]]], List[int]]: - """ - Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None. + """Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None. Args: generation (str): The AI-generated text to be processed. @@ -83,8 +76,7 @@ def run( def get_atomic_facts_from_paragraph( self, paragraphs: List[str], cost_estimate: Optional[Any] = None ) -> Tuple[List[Tuple[str, List[str]]], List[int]]: - """ - Extract atomic facts from a list of paragraphs. + """Extract atomic facts from a list of paragraphs. Args: paragraphs (List[str]): List of paragraph texts. @@ -148,8 +140,7 @@ def get_atomic_facts_from_paragraph( def get_init_atomic_facts_from_sentence( self, sentences: List[str], cost_estimate: Optional[Any] = None ) -> Any: - """ - Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None. + """Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None. Args: sentences (List[str]): List of sentences to process. @@ -208,8 +199,7 @@ def get_init_atomic_facts_from_sentence( def best_demos(query: str, bm25: BM25Okapi, demons_sents: List[str], k: int) -> List[str]: - """ - Retrieve the top matching demons for a given query using BM25. + """Retrieve the top matching demons for a given query using BM25. Args: query (str): The query sentence. @@ -225,8 +215,7 @@ def best_demos(query: str, bm25: BM25Okapi, demons_sents: List[str], k: int) -> return top_machings def text_to_sentences(text: str) -> List[str]: - """ - Transform InstructGPT output into a list of sentences. + """Transform InstructGPT output into a list of sentences. Args: text (str): The raw output text from InstructGPT. @@ -244,8 +233,7 @@ def text_to_sentences(text: str) -> List[str]: return sentences def normalize_answer(s: str) -> str: - """ - Lower text and remove punctuation, articles and extra whitespace. + """Lower text and remove punctuation, articles and extra whitespace. Args: s (str): The input string to normalize. 
@@ -270,8 +258,7 @@ def lower(text: str) -> str: MONTHS = [m.lower() for m in MONTHS] def is_num(text: str) -> bool: - """ - Check if the given text represents an integer number. + """Check if the given text represents an integer number. Args: text (str): The text to check. @@ -286,8 +273,7 @@ def is_num(text: str) -> bool: return False def is_date(text: str) -> bool: - """ - Determine if the given text represents a date. + """Determine if the given text represents a date. Args: text (str): The text to evaluate. @@ -302,8 +288,7 @@ def is_date(text: str) -> bool: return True def extract_numeric_values(text: str) -> set: - """ - Extract all unique numeric values from the text. + """Extract all unique numeric values from the text. Args: text (str): The input text. @@ -316,8 +301,7 @@ def extract_numeric_values(text: str) -> set: return set([value for value in numeric_values]) # convert the values to float and return as a list def detect_entities(text: str, nlp: spacy.lang.en.English) -> set: - """ - Detect relevant entities in the text using spaCy's NLP model. + """Detect relevant entities in the text using spaCy's NLP model. Args: text (str): The input text to analyze. @@ -358,8 +342,7 @@ def postprocess_atomic_facts( para_breaks: List[int], nlp: spacy.lang.en.English, ) -> Tuple[List[Tuple[str, List[str]]], List[int]]: - """ - Post-process atomic facts to fix minor issues and ensure consistency. + """Post-process atomic facts to fix minor issues and ensure consistency. Args: _atomic_facts (List[Tuple[str, List[str]]]): Initial list of atomic facts. @@ -429,8 +412,7 @@ def postprocess_atomic_facts( return new_atomic_facts, new_para_breaks def is_integer(s: str) -> bool: - """ - Check if the given string represents an integer. + """Check if the given string represents an integer. Args: s (str): The string to check. @@ -445,8 +427,7 @@ def is_integer(s: str) -> bool: return False def detect_initials(text: str) -> List[str]: - """ - Detect initials in the text. + """Detect initials in the text. Args: text (str): The input text. @@ -459,8 +440,7 @@ def detect_initials(text: str) -> List[str]: return [m for m in match] def fix_sentence_splitter(curr_sentences: List[str], initials: List[str]) -> List[str]: - """ - Fix sentence splitting issues caused by initials. + """Fix sentence splitting issues caused by initials. Args: curr_sentences (List[str]): List of current sentences. @@ -502,9 +482,7 @@ def fix_sentence_splitter(curr_sentences: List[str], initials: List[str]) -> Lis return sentences def main() -> None: - """ - Main function to demonstrate the usage of AtomicFactGenerator. - """ + """Main function to demonstrate the usage of AtomicFactGenerator.""" generator = AtomicFactGenerator("api.key", "demos", gpt3_cache_dir=None) atomic_facts, para_breaks = generator.run("Thierry Henry (born 17 August 1977) is a French professional football coach, pundit, and former player. He is considered one of the greatest strikers of all time, and one the greatest players of the Premier League history. He has been named Arsenal F.C's greatest ever player.\n\nHenry made his professional debut with Monaco in 1994 before signing for defending Serie A champions Juventus. 
However, limited playing time, coupled with disagreements with the club's hierarchy, led to him signing for Premier League club Arsenal for £11 million in 1999.") diff --git a/llments/eval/factscore/clm.py b/llments/eval/factscore/clm.py index 72eba39..74a5d8a 100644 --- a/llments/eval/factscore/clm.py +++ b/llments/eval/factscore/clm.py @@ -3,9 +3,7 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -""" -CLM (Causal Language Model) Module -""" +"""CLM (Causal Language Model) Module.""" import numpy as np import torch from tqdm import tqdm @@ -18,8 +16,7 @@ from factscore.lm import LM class CLM(LM): - """ - CLM (Causal Language Model) Class + """CLM (Causal Language Model) Class. This class extends the `LM` base class to provide functionalities specific to causal language modeling. It leverages pre-trained models from Hugging Face's Transformers library, enabling text generation @@ -38,8 +35,7 @@ def __init__( model_dir: str, cache_file: Optional[str] = None, ) -> None: - """ - Initialize the CLM (Causal Language Model) instance. + """Initialize the CLM (Causal Language Model) instance. Args: model_name (str): Name of the pre-trained language model. @@ -53,8 +49,7 @@ def __init__( super().__init__(cache_file) def load_model(self) -> None: - """ - Load the pre-trained causal language model and its tokenizer. + """Load the pre-trained causal language model and its tokenizer. This method loads the model from the specified directory, converts it to int8 precision for efficient GPU utilization, and initializes the tokenizer. @@ -75,8 +70,7 @@ def _generate( end_if_second_newline: bool = False, verbose: bool = False, ) -> Union[Tuple[str, np.ndarray], Tuple[List[str], List[np.ndarray]]]: - """ - Generate text based on input prompts using the causal language model. + """Generate text based on input prompts using the causal language model. Args: prompts (Union[str, List[str]]): Single prompt string or a list of prompt strings. diff --git a/llments/eval/factscore/download_data.py b/llments/eval/factscore/download_data.py index 6351b2d..650c81b 100644 --- a/llments/eval/factscore/download_data.py +++ b/llments/eval/factscore/download_data.py @@ -1,6 +1,4 @@ -""" -Download Data Module -""" +"""Download Data Module.""" import argparse import os import subprocess @@ -10,8 +8,7 @@ from typing import Tuple def download_file(_id: str, dest: str, cache_dir: str) -> None: - """ - Download a file from a given URL or Google Drive ID to the specified destination. + """Download a file from a given URL or Google Drive ID to the specified destination. Args: _id (str): The URL or Google Drive ID of the file to download. @@ -66,8 +63,7 @@ def smart_tokenizer_and_embedding_resize( tokenizer: transformers.PreTrainedTokenizer, model: transformers.PreTrainedModel, ) -> None: - """ - Resize tokenizer and embedding. + """Resize tokenizer and embedding. Note: This is the unoptimized version that may make your embedding size not be divisible by 64. @@ -98,8 +94,8 @@ def recover_instruct_llama( device: str = "cpu", test_recovered_model: bool = False ) -> Tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]: - """ - Recover an instruct LLaMA model by adding state dictionaries from a raw model to a recovered model. + """Recover an instruct LLaMA model by adding state dictionaries from a raw model to a recovered model. + Heavily adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/main/weight_diff.py. 
Args: diff --git a/llments/eval/factscore/factscorer.py b/llments/eval/factscore/factscorer.py index 5243afd..c88ba70 100644 --- a/llments/eval/factscore/factscorer.py +++ b/llments/eval/factscore/factscorer.py @@ -1,6 +1,4 @@ -""" -FactScore Scoring Module -""" +"""FactScore Scoring Module.""" import argparse import string import json @@ -18,8 +16,7 @@ from factscore.retrieval import DocDB, Retrieval class FactScorer: - """ - FactScorer Class + """FactScorer Class. This class integrates various language models and retrieval mechanisms to evaluate the factual accuracy of generated text. It supports different configurations, including retrieval-based models with ChatGPT, @@ -52,8 +49,7 @@ def __init__( abstain_detection_type: Optional[str] = None, batch_size: int = 256, ) -> None: - """ - Initialize the FactScorer instance. + """Initialize the FactScorer instance. Args: model_name (str, optional): Configuration of the language model to use. @@ -108,8 +104,7 @@ def __init__( self.lm = None def save_cache(self) -> None: - """ - Save caches for the language model, NPM instances, and retrieval instances. + """Save caches for the language model, NPM instances, and retrieval instances. This method ensures that any new entries added to the caches are persisted to their respective cache files to optimize performance and avoid redundant computations. @@ -128,8 +123,7 @@ def register_knowledge_source( db_path: Optional[str] = None, data_path: Optional[str] = None, ) -> None: - """ - Register a new knowledge source for retrieval. + """Register a new knowledge source for retrieval. This method initializes a new `DocDB` and `Retrieval` instance for the specified knowledge source. If NPM is included in the model configuration, it also initializes an `NPM` instance for the knowledge source. @@ -170,8 +164,7 @@ def print_cost_estimates( task: str, model: str, ) -> None: - """ - Print the estimated cost of OpenAI API usage based on the number of tokens. + """Print the estimated cost of OpenAI API usage based on the number of tokens. Args: total_words (int): Total number of words to be processed. @@ -207,8 +200,7 @@ def get_score( knowledge_source: Optional[str] = None, verbose: bool = False, ) -> Dict[str, Any]: - """ - Compute the factual accuracy score for the provided generations based on topics. + """Compute the factual accuracy score for the provided generations based on topics. This method retrieves relevant passages for each topic, generates or uses provided atomic facts, evaluates whether the generated content is supported by the retrieved knowledge, and computes @@ -333,8 +325,7 @@ def _get_score( knowledge_source: str, cost_estimate: Optional[str] = None, ) -> Union[List[Dict[str, bool]], int]: - """ - Compute support scores for each atomic fact based on the knowledge source. + """Compute support scores for each atomic fact based on the knowledge source. This internal method evaluates whether each atomic fact is supported by the retrieved passages using the configured language model and NPM (if applicable). diff --git a/llments/eval/factscore/lm.py b/llments/eval/factscore/lm.py index 2fbed02..76dc114 100644 --- a/llments/eval/factscore/lm.py +++ b/llments/eval/factscore/lm.py @@ -1,14 +1,11 @@ -""" -LM (Language Model) Base Class Module -""" +"""LM (Language Model) Base Class Module.""" import pickle import os import time from typing import Dict, Any class LM(object): - """ - LM (Language Model) Base Class + """LM (Language Model) Base Class. 
This class serves as a base for language models, managing caching of generated outputs and defining the interface for loading models and generating text. It handles the storage @@ -21,8 +18,7 @@ class LM(object): add_n (int): Counter for the number of new cache entries added. """ def __init__(self, cache_file: str) -> None: - """ - Initialize the LM (Language Model) instance. + """Initialize the LM (Language Model) instance. Args: cache_file (str): Path to the cache file for storing generated outputs. @@ -33,8 +29,7 @@ def __init__(self, cache_file: str) -> None: self.add_n = 0 def load_model(self) -> None: - """ - Load the language model and put it as self.model + """Load the language model and put it as self.model. Raises: NotImplementedError: If the method is not implemented by a subclass. @@ -48,8 +43,7 @@ def generate( max_sequence_length: int = 2048, max_output_length: int = 128, ) -> Any: - """ - Generate text based on the input prompt. + """Generate text based on the input prompt. Args: prompt (str): The input prompt to generate text from. @@ -79,9 +73,7 @@ def generate( return generated def save_cache(self) -> None: - """ - Save the current cache to the cache file. - """ + """Save the current cache to the cache file.""" if self.add_n == 0: return @@ -93,8 +85,7 @@ def save_cache(self) -> None: pickle.dump(self.cache_dict, f) def load_cache(self, allow_retry: bool = True) -> Dict[str, Any]: - """ - Load the cache from the cache file. + """Load the cache from the cache file. Args: allow_retry (bool, optional): Whether to retry loading the cache in case of errors. @@ -119,7 +110,4 @@ def load_cache(self, allow_retry: bool = True) -> Dict[str, Any]: time.sleep(5) else: cache = {} - return cache - - - + return cache \ No newline at end of file diff --git a/llments/eval/factscore/npm.py b/llments/eval/factscore/npm.py index 5666240..bf5b8d4 100644 --- a/llments/eval/factscore/npm.py +++ b/llments/eval/factscore/npm.py @@ -1,6 +1,4 @@ -""" -NPM Language Model Module -""" +"""NPM Language Model Module.""" import numpy as np import torch from collections import defaultdict @@ -11,8 +9,7 @@ from factscore.retrieval import Retrieval def softmax(x: np.ndarray) -> np.ndarray: - """ - Compute the softmax of a given NumPy array. + """Compute the softmax of a given NumPy array. Args: x (np.ndarray): Input array. @@ -23,8 +20,7 @@ def softmax(x: np.ndarray) -> np.ndarray: return(np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum()) class NPM(LM): - """ - NPM Language Model integrating BM25 retrieval with a masked language model. + """NPM Language Model integrating BM25 retrieval with a masked language model. This class extends the `LM` base class and provides functionalities to tokenize, encode, decode, and compute probabilities based on input topics and questions. @@ -44,8 +40,7 @@ def __init__( model_name: str, cache_file: str, ) -> None: - """ - Initialize the NPM language model. + """Initialize the NPM language model. Args: bm25 (Retrieval): BM25 retrieval instance for fetching relevant passages. @@ -71,8 +66,7 @@ def __init__( super().__init__(cache_file=cache_file) def load_model(self) -> None: - """ - Load the pre-trained masked language model and move it to GPU. + """Load the pre-trained masked language model and move it to GPU. Raises: OSError: If the model cannot be loaded. @@ -82,9 +76,7 @@ def load_model(self) -> None: self.model.eval() def save_cache(self) -> None: - """ - Save the cache for both the language model and BM25 retrieval. 
- """ + """Save the cache for both the language model and BM25 retrieval.""" super().save_cache() self.bm25.save_cache() @@ -94,8 +86,7 @@ def tokenize( skip_special_tokens: bool = False, padding: bool = True, ) -> Tuple[torch.LongTensor, torch.LongTensor]: - """ - Tokenize a list of texts with optional skipping of special tokens and padding. + """Tokenize a list of texts with optional skipping of special tokens and padding. Args: texts (List[str]): List of text strings to tokenize. @@ -124,8 +115,7 @@ def tokenize( return torch.LongTensor(_all_input_ids), torch.LongTensor(_all_attention_mask) def decode(self, input_ids: List[int]) -> str: - """ - Decode a list of input IDs back into a string. + """Decode a list of input IDs back into a string. Args: input_ids (List[int]): List of token IDs. @@ -141,8 +131,7 @@ def encode( skip_special_tokens: bool = False, gt_input_ids: Optional[List[int]] = None, ) -> List[Tuple[float, np.ndarray]]: - """ - Encode a list of texts into probabilities and hidden states. + """Encode a list of texts into probabilities and hidden states. Args: texts (List[str]): List of text strings to encode. @@ -186,8 +175,7 @@ def encode( return results def get_probabilty(self, topic: str, question: str) -> float: - """ - Compute the probability of a question given a topic using BM25 and the masked language model. + """Compute the probability of a question given a topic using BM25 and the masked language model. Args: topic (str): The topic string. @@ -251,8 +239,4 @@ def get_probabilty(self, topic: str, question: str) -> float: self.cache_dict[cache_key] = np.mean(probs) self.add_n += 1 - return self.cache_dict[cache_key] - - - - + return self.cache_dict[cache_key] \ No newline at end of file diff --git a/llments/eval/factscore/openai_lm.py b/llments/eval/factscore/openai_lm.py index 65b852f..3deb944 100644 --- a/llments/eval/factscore/openai_lm.py +++ b/llments/eval/factscore/openai_lm.py @@ -1,6 +1,4 @@ -""" -OpenAI Model Module -""" +"""OpenAI Model Module.""" from factscore.lm import LM import openai import sys @@ -11,8 +9,7 @@ from typing import Optional, List, Tuple, Dict, Any class OpenAIModel(LM): - """ - OpenAI Language Model Class + """OpenAI Language Model Class. This class extends the `LM` base class to interface with OpenAI's language models, including ChatGPT and InstructGPT. It handles API key management, text generation via the OpenAI API, and caching of @@ -30,8 +27,7 @@ def __init__( cache_file: Optional[str] = None, key_path: str = "api.key" ) -> None: - """ - Initialize the OpenAIModel instance. + """Initialize the OpenAIModel instance. Args: model_name (str): Name of the OpenAI model to use (e.g., "ChatGPT", "InstructGPT"). @@ -46,8 +42,7 @@ def __init__( super().__init__(cache_file) def load_model(self) -> None: - """ - Load the OpenAI API key and set the model name. + """Load the OpenAI API key and set the model name. This method reads the API key from the specified file and configures the OpenAI API client. It also sets the `model` attribute to the specified `model_name`. @@ -68,8 +63,7 @@ def _generate( max_sequence_length: int = 2048, max_output_length: int = 128 ) -> Tuple[str, Dict[str, Any]]: - """ - Generate text using the OpenAI API based on the input prompt. + """Generate text using the OpenAI API based on the input prompt. This method handles caching of generated outputs and interacts with the OpenAI API to produce text completions. It supports different models like ChatGPT and InstructGPT. 
@@ -115,8 +109,7 @@ def call_ChatGPT( temp: float = 0.7, verbose: bool = False ) -> Dict[str, Any]: - """ - Call the OpenAI ChatCompletion API to generate a response based on the input message. + """Call the OpenAI ChatCompletion API to generate a response based on the input message. Args: message (List[Dict[str, str]]): The input message(s) to send to the ChatCompletion API. @@ -164,8 +157,7 @@ def call_GPT3( echo: bool = False, verbose: bool = False ) -> Dict[str, Any]: - """ - Call the OpenAI GPT-3 API to generate a response based on the input prompt. + """Call the OpenAI GPT-3 API to generate a response based on the input prompt. This function handles API rate limits by implementing an exponential backoff retry mechanism. It continues to retry until a successful response is received or a critical error occurs. @@ -206,4 +198,4 @@ def call_GPT3( assert False logging.error("API error: %s (%d)" % (error, num_rate_errors)) time.sleep(np.power(2, num_rate_errors)) - return response + return response \ No newline at end of file diff --git a/llments/eval/factscore/retrieval.py b/llments/eval/factscore/retrieval.py index e8eda59..ddbab48 100644 --- a/llments/eval/factscore/retrieval.py +++ b/llments/eval/factscore/retrieval.py @@ -1,6 +1,4 @@ -""" -Document Database and Retrieval Module -""" +"""Document Database and Retrieval Module.""" import json import time import os @@ -17,8 +15,7 @@ MAX_LENGTH = 256 class DocDB(object): - """ - SQLite-backed Document Storage. + """SQLite-backed Document Storage. Implements get_doc_text(doc_id). @@ -28,8 +25,7 @@ class DocDB(object): add_n (int): Counter for the number of new documents added to the cache. """ def __init__(self, db_path: Optional[str] = None, data_path: Optional[str] = None) -> None: - """ - Initialize the DocDB instance. + """Initialize the DocDB instance. Connects to the SQLite database at `db_path`. If the database is empty, it builds the database from the provided `data_path`. @@ -54,22 +50,18 @@ def __init__(self, db_path: Optional[str] = None, data_path: Optional[str] = Non self.build_db(self.db_path, data_path) def __enter__(self) -> 'DocDB': - """ - Enter the runtime context related to this object. + """Enter the runtime context related to this object. Returns: DocDB: The DocDB instance itself. """ return self def __exit__(self, *args) -> None: - """ - Exit the runtime context and close the database connection. - """ + """Exit the runtime context and close the database connection.""" self.close() def path(self) -> str: - """ - Return the path to the file that backs this database. + """Return the path to the file that backs this database. Returns: str: Path to the SQLite database file. @@ -77,14 +69,11 @@ def path(self) -> str: return self.path def close(self) -> None: - """ - Close the connection to the database. - """ + """Close the connection to the database.""" self.connection.close() def build_db(self, db_path: str, data_path: str) -> None: - """ - Build the SQLite database from raw JSON data. + """Build the SQLite database from raw JSON data. This method reads raw data from `data_path`, processes it using a tokenizer, and inserts the documents into the SQLite database. @@ -147,8 +136,7 @@ def build_db(self, db_path: str, data_path: str) -> None: self.connection.close() def get_text_from_title(self, title: str) -> List[Dict[str, str]]: - """ - Fetch the raw text of the doc for 'doc_id'. + """Fetch the raw text of the doc for 'doc_id'. Args: title (str): The title of the document to fetch. 
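The `call_ChatGPT` and `call_GPT3` wrappers above recover from rate limits by sleeping for exponentially growing intervals (`time.sleep(np.power(2, num_rate_errors))`). The following is a generic sketch of that backoff loop, written against an arbitrary callable rather than a specific OpenAI client; the function name and signature here are illustrative assumptions, not the module's API.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_backoff(fn: Callable[[], T], max_retries: int = 8) -> T:
    """Retry fn with exponential backoff, in the spirit of the wrappers above."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception as error:  # a real wrapper would catch only the client's rate-limit error
            wait = 2 ** attempt  # 1s, 2s, 4s, ... mirrors np.power(2, num_rate_errors)
            print(f"API error: {error}; retrying in {wait}s")
            time.sleep(wait)
    raise RuntimeError(f"giving up after {max_retries} retries")


# Usage: wrap the API call in a zero-argument lambda,
# e.g. result = call_with_backoff(lambda: some_api_call(prompt)),
# where some_api_call is a placeholder for the real client call.
```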
@@ -170,8 +158,7 @@ def get_text_from_title(self, title: str) -> List[Dict[str, str]]: return results class Retrieval(object): - """ - Document Retrieval Class. + """Document Retrieval Class. Attributes: db (DocDB): Instance of the DocDB class for accessing documents. @@ -193,8 +180,7 @@ def __init__( retrieval_type: str = "gtr-t5-large", batch_size: Optional[int] = None ) -> None: - """ - Initialize the Retrieval instance. + """Initialize the Retrieval instance. Args: db (DocDB): Instance of the DocDB class for accessing documents. @@ -221,8 +207,7 @@ def __init__( self.add_n_embed = 0 def load_encoder(self) -> None: - """ - Load the sentence transformer encoder for embedding-based retrieval. + """Load the sentence transformer encoder for embedding-based retrieval. Raises: ValueError: If `batch_size` is not set for transformer-based retrieval. @@ -234,9 +219,7 @@ def load_encoder(self) -> None: assert self.batch_size is not None def load_cache(self) -> None: - """ - Load retrieval and embedding caches from the specified cache files. - """ + """Load retrieval and embedding caches from the specified cache files.""" if os.path.exists(self.cache_path): with open(self.cache_path, "r") as f: self.cache = json.load(f) @@ -249,9 +232,7 @@ def load_cache(self) -> None: self.embed_cache = {} def save_cache(self) -> None: - """ - Save retrieval and embedding caches to the specified cache files. - """ + """Save retrieval and embedding caches to the specified cache files.""" if self.add_n > 0: if os.path.exists(self.cache_path): with open(self.cache_path, "r") as f: @@ -277,8 +258,7 @@ def get_bm25_passages( passages: List[Dict[str, str]], k: int ) -> List[Dict[str, str]]: - """ - Retrieve top-k passages using BM25. + """Retrieve top-k passages using BM25. Args: topic (str): The topic associated with the query. @@ -306,8 +286,7 @@ def get_gtr_passages( passages: List[Dict[str, str]], k: int ) -> List[Dict[str, str]]: - """ - Retrieve top-k passages using transformer-based retrieval (e.g., GTR). + """Retrieve top-k passages using transformer-based retrieval (e.g., GTR). Args: topic (str): The topic associated with the query. @@ -340,8 +319,7 @@ def get_passages( question: str, k: int ) -> List[Dict[str, str]]: - """ - Retrieve top-k passages based on the topic and question using the specified retrieval method. + """Retrieve top-k passages based on the topic and question using the specified retrieval method. Args: topic (str): The topic associated with the query. @@ -364,9 +342,4 @@ def get_passages( self.add_n += 1 - return self.cache[cache_key] - - - - - + return self.cache[cache_key] \ No newline at end of file diff --git a/llments/eval/factscore/utils.py b/llments/eval/factscore/utils.py index 14360e0..d8642f4 100644 --- a/llments/eval/factscore/utils.py +++ b/llments/eval/factscore/utils.py @@ -3,15 +3,12 @@ # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -""" -Utilities Module -""" +"""Utilities Module.""" import torch from torch import nn def assert_all_approx_close(a: torch.Tensor, b: torch.Tensor, rtol: float, atol: float, count: int) -> None: - """ - Assert that all elements in tensors `a` and `b` are approximately close within the given tolerances. + """Assert that all elements in tensors `a` and `b` are approximately close within the given tolerances. If more than `count` elements are not close, print a message and perform an assertion. 
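`Retrieval.get_bm25_passages` above ranks the passages returned by `DocDB` against the query with BM25 and keeps the top k. Here is a self-contained sketch of that ranking step using the `rank_bm25` package; the passages, query, and k value are made up for illustration.

```python
from rank_bm25 import BM25Okapi

# Illustrative passages standing in for the chunks DocDB returns for a title.
passages = [
    {"title": "Thierry Henry", "text": "Thierry Henry is a French football coach and former striker."},
    {"title": "Thierry Henry", "text": "He joined Arsenal from Juventus in 1999 for £11 million."},
    {"title": "Thierry Henry", "text": "Henry won two league titles during his time in North London."},
]

# Build the BM25 index over whitespace-tokenized passage texts.
bm25 = BM25Okapi([p["text"].lower().split() for p in passages])

query = "When did Henry sign for Arsenal?"
scores = bm25.get_scores(query.lower().split())

# Keep the top-k passages by score, mirroring the k argument of get_passages.
k = 2
top_k = [passages[i] for i in sorted(range(len(passages)), key=lambda i: -scores[i])[:k]]
print([p["text"] for p in top_k])
```

The repository's `Retrieval` class additionally caches the ranked results per query, as the `self.cache[cache_key]` handling in `get_passages` above suggests.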
@@ -25,7 +22,6 @@ def assert_all_approx_close(a: torch.Tensor, b: torch.Tensor, rtol: float, atol: Raises: AssertionError: If the number of non-close elements exceeds `count`. """ - idx = torch.isclose(a.float(), b.float(), rtol, atol) sumval = (idx==0).sum().item() if sumval > count: @@ -36,10 +32,10 @@ def assert_all_approx_close(a: torch.Tensor, b: torch.Tensor, rtol: float, atol: print(e) def get_memory_footprint(model: nn.Module, return_buffers: bool = True) -> int: - """ - Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + """Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the - PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2. Args: model (nn.Module): The PyTorch model to evaluate. @@ -57,8 +53,7 @@ def get_memory_footprint(model: nn.Module, return_buffers: bool = True) -> int: return mem def ـreplace_linear_with_int8linear(model: nn.Module, modules_to_not_convert: str = "lm_head") -> None: - """ - Recursively replace all `nn.Linear` layers in a model with `QuantizedLinearInt8`, except for specified modules. + """Recursively replace all `nn.Linear` layers in a model with `QuantizedLinearInt8`, except for specified modules. Args: model (nn.Module): The PyTorch model in which to replace linear layers. @@ -76,8 +71,8 @@ def ـreplace_linear_with_int8linear(model: nn.Module, modules_to_not_convert: s return class QuantizedLinearInt8(nn.Module): - """ - A simple but effictive implmenetion of Int8 quantization for linear layers. + """A simple but effictive implmenetion of Int8 quantization for linear layers. + The weights are quantized and stored as Int8, which saves ~50% of the gpu memory. During the forwared pass, the weights are de-quantized back to fp16 to do multiplication. @@ -98,8 +93,7 @@ class QuantizedLinearInt8(nn.Module): bias (Optional[torch.Tensor]): Bias tensor, if present. """ def __init__(self, linear_layer: nn.Linear) -> None: - """ - Initialize the QuantizedLinearInt8 layer. + """Initialize the QuantizedLinearInt8 layer. Args: linear_layer (nn.Linear): The original linear layer to be quantized. @@ -122,8 +116,7 @@ def __init__(self, linear_layer: nn.Linear) -> None: ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the quantized linear layer. + """Forward pass of the quantized linear layer. Args: x (torch.Tensor): Input tensor. @@ -136,8 +129,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def convert_model_to_int8_on_gpu(model: nn.Module, device: str) -> nn.Module: - """ - Quantize a PyTorch model to int8 and move it to the specified GPU device. + """Quantize a PyTorch model to int8 and move it to the specified GPU device. Args: model (nn.Module): The PyTorch model to be quantized. @@ -166,4 +158,4 @@ def convert_model_to_int8_on_gpu(model: nn.Module, device: str) -> nn.Module: memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)') - return model + return model \ No newline at end of file
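For reference, the `QuantizedLinearInt8` docstring above describes storing each linear layer's weights as int8 with a per-row scale and de-quantizing them back to floating point during the forward pass. Below is a minimal PyTorch sketch of that scheme; it illustrates the idea under those assumptions and is not the module's actual implementation.

```python
import torch
from torch import nn


class Int8Linear(nn.Module):
    """Symmetric per-output-channel int8 weight quantization, dequantized at forward time."""

    def __init__(self, linear: nn.Linear) -> None:
        super().__init__()
        weight = linear.weight.data
        # One scale per output row so the largest magnitude maps to 127; clamp avoids divide-by-zero.
        scale = (weight.abs().max(dim=1, keepdim=True).values / 127.0).clamp(min=1e-8)
        self.register_buffer("scale", scale)
        self.register_buffer("weight_int8", torch.round(weight / scale).to(torch.int8))
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # De-quantize back to the activation dtype before the matmul.
        weight = self.weight_int8.to(x.dtype) * self.scale.to(x.dtype)
        return nn.functional.linear(x, weight, self.bias)


# Usage: quantize a single layer and compare against the original.
layer = nn.Linear(16, 8)
q_layer = Int8Linear(layer)
x = torch.randn(4, 16)
print(torch.allclose(layer(x), q_layer(x), atol=1e-1))  # close, up to quantization error
```

In the module itself, `convert_model_to_int8_on_gpu` applies a replacement of this kind to every `nn.Linear` except `lm_head`, per the `ـreplace_linear_with_int8linear` docstring above.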