Commit
ruff changes
AakritiKinra authored Dec 18, 2024
1 parent 873b2a3 commit 67f1bb4
Showing 10 changed files with 101 additions and 221 deletions.
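
Every hunk below applies the same mechanical fix, bringing docstrings in line with ruff's pydocstyle (D) rules: the summary moves onto the same line as the opening triple quotes and gains a trailing period. Removed lines are prefixed with -, added lines with +, and unmarked lines are unchanged context. A command along the lines of ruff check --select D --fix would produce fixes of this shape, though the exact invocation is not recorded in the commit.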
22 changes: 7 additions & 15 deletions llments/eval/factscore/abstain_detection.py
@@ -1,6 +1,4 @@
"""
Abstain Detection Module
"""
"""Abstain Detection Module."""
import numpy as np
import re
from typing import List
@@ -20,8 +18,7 @@
]

def remove_citation(text: str) -> str:
"""
Remove citation references and fix specific starting phrases in the text.
"""Remove citation references and fix specific starting phrases in the text.
Args:
text (str): The input text from which citations are to be removed.
@@ -35,8 +32,7 @@ def remove_citation(text: str) -> str:
return text

def is_invalid_ppl(text: str) -> bool:
"""
Check if the text starts with any invalid phrases indicating insufficient information.
"""Check if the text starts with any invalid phrases indicating insufficient information.
Args:
text (str): The input text to be checked.
@@ -47,8 +43,7 @@ def is_invalid_ppl(text: str) -> bool:
return np.any([text.lower().startswith(mention.lower()) for mention in invalid_ppl_mentions])

def is_invalid_paragraph_ppl(text: str) -> bool:
"""
Determine if a paragraph is invalid based on its content.
"""Determine if a paragraph is invalid based on its content.
A paragraph is considered invalid if it is empty or contains any invalid phrases.
@@ -61,8 +56,7 @@ def is_invalid_paragraph_ppl(text: str) -> bool:
return len(text.strip())==0 or np.any([mention.lower() in text.lower() for mention in invalid_ppl_mentions])

def perplexity_ai_abstain_detect(generation: str) -> bool:
"""
Detect if the AI generation should abstain based on perplexity analysis.
"""Detect if the AI generation should abstain based on perplexity analysis.
This function removes citations from the generation, checks if it starts with any invalid phrases,
and verifies that all paragraphs contain valid information.
@@ -88,8 +82,7 @@ def perplexity_ai_abstain_detect(generation: str) -> bool:
return False

def generic_abstain_detect(generation: str) -> bool:
"""
Detect if the generation should abstain based on generic abstain phrases.
"""Detect if the generation should abstain based on generic abstain phrases.
Args:
generation (str): The generated text to be analyzed.
@@ -100,8 +93,7 @@ def generic_abstain_detect(generation: str) -> bool:
return generation.startswith("I'm sorry") or "provide more" in generation

def is_response_abstained(generation: str, fn_type: str) -> bool:
"""
Determine if the response should be abstained based on the specified detection function type.
"""Determine if the response should be abstained based on the specified detection function type.
Args:
generation (str): The generated text to be analyzed.
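
A minimal usage sketch for the abstain-detection helpers in this file, assuming the module is importable as llments.eval.factscore.abstain_detection and that fn_type accepts the detector names "generic" and "perplexity_ai" (both are assumptions, not shown in the hunks above):

# Sketch: exercising the abstain detectors on a stock refusal.
# Import path and fn_type values are assumed, as noted above.
from llments.eval.factscore.abstain_detection import (
    generic_abstain_detect,
    is_response_abstained,
)

refusal = "I'm sorry, I cannot find enough information to answer that."
print(generic_abstain_detect(refusal))                    # True: starts with "I'm sorry"
print(is_response_abstained(refusal, fn_type="generic"))  # True via the generic detector
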
60 changes: 19 additions & 41 deletions llments/eval/factscore/atomic_facts.py
@@ -1,6 +1,4 @@
"""
Atomic Facts Module
"""
"""Atomic Facts Module."""
import json
import numpy as np
import re
@@ -17,8 +15,7 @@
nltk.download("punkt")

class AtomicFactGenerator:
"""
A generator class to convert AI-generated text into atomic facts.
"""A generator class to convert AI-generated text into atomic facts.
Attributes:
nlp (spacy.lang.en.English): The spaCy language model for NLP tasks.
@@ -34,8 +31,7 @@ def __init__(
demon_dir: str = "/factscore_data",
gpt3_cache_file: Optional[str] = None,
) -> None:
"""
Initialize the AtomicFactGenerator.
"""Initialize the AtomicFactGenerator.
Args:
key_path (str, optional): Path to the OpenAI API key file. Defaults to "key.txt".
@@ -56,16 +52,13 @@ def __init__(
self.bm25 = BM25Okapi(tokenized_corpus)

def save_cache(self) -> None:
"""
Save the OpenAI language model cache.
"""
"""Save the OpenAI language model cache."""
self.openai_lm.save_cache()

def run(
self, generation: str, cost_estimate: Optional[Any] = None
) -> Tuple[List[Tuple[str, List[str]]], List[int]]:
"""
Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None.
"""Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None.
Args:
generation (str): The AI-generated text to be processed.
@@ -83,8 +76,7 @@ def run(
def get_atomic_facts_from_paragraph(
self, paragraphs: List[str], cost_estimate: Optional[Any] = None
) -> Tuple[List[Tuple[str, List[str]]], List[int]]:
"""
Extract atomic facts from a list of paragraphs.
"""Extract atomic facts from a list of paragraphs.
Args:
paragraphs (List[str]): List of paragraph texts.
@@ -148,8 +140,7 @@ def get_atomic_facts_from_paragraph(
def get_init_atomic_facts_from_sentence(
self, sentences: List[str], cost_estimate: Optional[Any] = None
) -> Any:
"""
Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None.
"""Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None.
Args:
sentences (List[str]): List of sentences to process.
@@ -208,8 +199,7 @@ def get_init_atomic_facts_from_sentence(


def best_demos(query: str, bm25: BM25Okapi, demons_sents: List[str], k: int) -> List[str]:
"""
Retrieve the top matching demons for a given query using BM25.
"""Retrieve the top matching demons for a given query using BM25.
Args:
query (str): The query sentence.
@@ -225,8 +215,7 @@ def best_demos(query: str, bm25: BM25Okapi, demons_sents: List[str], k: int) ->
return top_machings

def text_to_sentences(text: str) -> List[str]:
"""
Transform InstructGPT output into a list of sentences.
"""Transform InstructGPT output into a list of sentences.
Args:
text (str): The raw output text from InstructGPT.
@@ -244,8 +233,7 @@ def text_to_sentences(text: str) -> List[str]:
return sentences

def normalize_answer(s: str) -> str:
"""
Lower text and remove punctuation, articles and extra whitespace.
"""Lower text and remove punctuation, articles and extra whitespace.
Args:
s (str): The input string to normalize.
@@ -270,8 +258,7 @@ def lower(text: str) -> str:
MONTHS = [m.lower() for m in MONTHS]

def is_num(text: str) -> bool:
"""
Check if the given text represents an integer number.
"""Check if the given text represents an integer number.
Args:
text (str): The text to check.
@@ -286,8 +273,7 @@ def is_num(text: str) -> bool:
return False

def is_date(text: str) -> bool:
"""
Determine if the given text represents a date.
"""Determine if the given text represents a date.
Args:
text (str): The text to evaluate.
@@ -302,8 +288,7 @@ def is_date(text: str) -> bool:
return True

def extract_numeric_values(text: str) -> set:
"""
Extract all unique numeric values from the text.
"""Extract all unique numeric values from the text.
Args:
text (str): The input text.
@@ -316,8 +301,7 @@ def extract_numeric_values(text: str) -> set:
return set([value for value in numeric_values]) # convert the values to float and return as a list

def detect_entities(text: str, nlp: spacy.lang.en.English) -> set:
"""
Detect relevant entities in the text using spaCy's NLP model.
"""Detect relevant entities in the text using spaCy's NLP model.
Args:
text (str): The input text to analyze.
@@ -358,8 +342,7 @@ def postprocess_atomic_facts(
para_breaks: List[int],
nlp: spacy.lang.en.English,
) -> Tuple[List[Tuple[str, List[str]]], List[int]]:
"""
Post-process atomic facts to fix minor issues and ensure consistency.
"""Post-process atomic facts to fix minor issues and ensure consistency.
Args:
_atomic_facts (List[Tuple[str, List[str]]]): Initial list of atomic facts.
@@ -429,8 +412,7 @@ def postprocess_atomic_facts(
return new_atomic_facts, new_para_breaks

def is_integer(s: str) -> bool:
"""
Check if the given string represents an integer.
"""Check if the given string represents an integer.
Args:
s (str): The string to check.
@@ -445,8 +427,7 @@ def is_integer(s: str) -> bool:
return False

def detect_initials(text: str) -> List[str]:
"""
Detect initials in the text.
"""Detect initials in the text.
Args:
text (str): The input text.
@@ -459,8 +440,7 @@ def detect_initials(text: str) -> List[str]:
return [m for m in match]

def fix_sentence_splitter(curr_sentences: List[str], initials: List[str]) -> List[str]:
"""
Fix sentence splitting issues caused by initials.
"""Fix sentence splitting issues caused by initials.
Args:
curr_sentences (List[str]): List of current sentences.
@@ -502,9 +482,7 @@ def fix_sentence_splitter(curr_sentences: List[str], initials: List[str]) -> Lis
return sentences

def main() -> None:
"""
Main function to demonstrate the usage of AtomicFactGenerator.
"""
"""Main function to demonstrate the usage of AtomicFactGenerator."""
generator = AtomicFactGenerator("api.key", "demos", gpt3_cache_dir=None)
atomic_facts, para_breaks = generator.run("Thierry Henry (born 17 August 1977) is a French professional football coach, pundit, and former player. He is considered one of the greatest strikers of all time, and one the greatest players of the Premier League history. He has been named Arsenal F.C's greatest ever player.\n\nHenry made his professional debut with Monaco in 1994 before signing for defending Serie A champions Juventus. However, limited playing time, coupled with disagreements with the club's hierarchy, led to him signing for Premier League club Arsenal for £11 million in 1999.")

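
The BM25 lookup that best_demos performs can be sketched directly with rank_bm25; the demonstration sentences below are invented for illustration (the real generator indexes word-tokenized sentences loaded from its demon_dir):

# Sketch: ranking demonstration sentences against a query, as best_demos does.
# demons_sents is made-up example data.
from rank_bm25 import BM25Okapi

demons_sents = [
    "Thierry Henry made his professional debut with Monaco in 1994.",
    "He signed for Arsenal for 11 million pounds in 1999.",
]
bm25 = BM25Okapi([s.split() for s in demons_sents])
query = "When did Henry join Arsenal?"
print(bm25.get_top_n(query.split(), demons_sents, n=1))
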
16 changes: 5 additions & 11 deletions llments/eval/factscore/clm.py
@@ -3,9 +3,7 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
CLM (Causal Language Model) Module
"""
"""CLM (Causal Language Model) Module."""
import numpy as np
import torch
from tqdm import tqdm
@@ -18,8 +16,7 @@
from factscore.lm import LM

class CLM(LM):
"""
CLM (Causal Language Model) Class
"""CLM (Causal Language Model) Class.
This class extends the `LM` base class to provide functionalities specific to causal language modeling.
It leverages pre-trained models from Hugging Face's Transformers library, enabling text generation
@@ -38,8 +35,7 @@ def __init__(
model_dir: str,
cache_file: Optional[str] = None,
) -> None:
"""
Initialize the CLM (Causal Language Model) instance.
"""Initialize the CLM (Causal Language Model) instance.
Args:
model_name (str): Name of the pre-trained language model.
@@ -53,8 +49,7 @@ def __init__(
super().__init__(cache_file)

def load_model(self) -> None:
"""
Load the pre-trained causal language model and its tokenizer.
"""Load the pre-trained causal language model and its tokenizer.
This method loads the model from the specified directory, converts it to int8 precision
for efficient GPU utilization, and initializes the tokenizer.
@@ -75,8 +70,7 @@ def _generate(
end_if_second_newline: bool = False,
verbose: bool = False,
) -> Union[Tuple[str, np.ndarray], Tuple[List[str], List[np.ndarray]]]:
"""
Generate text based on input prompts using the causal language model.
"""Generate text based on input prompts using the causal language model.
Args:
prompts (Union[str, List[str]]): Single prompt string or a list of prompt strings.
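
Constructing the wrapper defined in this file might look as follows; the model name, directory, and cache path are placeholders, and the import path is inferred from the "from factscore.lm import LM" line above:

# Sketch: loading a causal LM through the CLM wrapper.
# All names and paths below are illustrative placeholders.
from factscore.clm import CLM

lm = CLM(
    model_name="inst-llama-7B",
    model_dir="/models/inst-llama-7B",
    cache_file="/cache/inst-llama-7B.pkl",
)
lm.load_model()  # loads the checkpoint in int8 and its tokenizer
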
14 changes: 5 additions & 9 deletions llments/eval/factscore/download_data.py
@@ -1,6 +1,4 @@
"""
Download Data Module
"""
"""Download Data Module."""
import argparse
import os
import subprocess
@@ -10,8 +8,7 @@
from typing import Tuple

def download_file(_id: str, dest: str, cache_dir: str) -> None:
"""
Download a file from a given URL or Google Drive ID to the specified destination.
"""Download a file from a given URL or Google Drive ID to the specified destination.
Args:
_id (str): The URL or Google Drive ID of the file to download.
@@ -66,8 +63,7 @@ def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
) -> None:
"""
Resize tokenizer and embedding.
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
@@ -98,8 +94,8 @@ def recover_instruct_llama(
device: str = "cpu",
test_recovered_model: bool = False
) -> Tuple[transformers.PreTrainedModel, transformers.PreTrainedTokenizer]:
"""
Recover an instruct LLaMA model by adding state dictionaries from a raw model to a recovered model.
"""Recover an instruct LLaMA model by adding state dictionaries from a raw model to a recovered model.
Heavily adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/main/weight_diff.py.
Args:
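
The core of smart_tokenizer_and_embedding_resize is the standard transformers resize, sketched below with a stock gpt2 checkpoint purely for illustration (per its docstring the helper is the unoptimized version, and the real function may also re-initialize the new embedding rows differently):

# Sketch: add special tokens, then grow the embedding matrix to match.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))  # may leave the size not divisible by 64
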
(Diffs for the remaining six changed files are not expanded in this view.)
