diff --git a/bertrend/BERTopicModel.py b/bertrend/BERTopicModel.py index 00590af..aab9d4d 100644 --- a/bertrend/BERTopicModel.py +++ b/bertrend/BERTopicModel.py @@ -68,19 +68,39 @@ class BERTopicModel: Utility class to manage and configure BERTopic instances with custom parameters. """ - def __init__(self, config_file: str | Path = BERTOPIC_DEFAULT_CONFIG_PATH): + def __init__(self, config: str | Path | dict = BERTOPIC_DEFAULT_CONFIG_PATH): """ Initialize a class from a TOML config file. - `config_file` can be: + `config` can be: - a `str` representing the TOML file - a `Path` to a TOML file + - a `dict` (with the same structure of the default config) containing values to be overridden compared to the default configuration To see file format and list of parameters: bertrend/config/topic_model_default_config.toml """ - self.config_file = config_file + if isinstance(config, str) or isinstance(config, Path): + try: + self.config = load_toml_config(config) + except Exception as e: + raise Exception(f"Failed to load TOML config: {e}") + elif isinstance(config, dict): + # load default config + self.config = load_toml_config(BERTOPIC_DEFAULT_CONFIG_PATH) + # overrides keys with provided dict + for section, settings in config.items(): + if section in config: + self.config[section].update( + settings + ) # Update the settings in that section + else: + self.config[section] = settings # If section doesn't exist, add it + else: + raise TypeError( + f"Config must be a string, Path or dict object, got: {type(config)}" + ) - # Load config file - self.config = self._load_config() + # Update config file (depending on language, etc.) + self._update_config() # Initialize models based on those parameters self._initialize_models() @@ -92,34 +112,34 @@ def __init__(self, config_file: str | Path = BERTOPIC_DEFAULT_CONFIG_PATH): ) ) - def _load_config(self) -> dict: + @classmethod + def get_default_config(cls) -> dict: + """Helper function to get default config. Useful to modify a s""" + return load_toml_config(BERTOPIC_DEFAULT_CONFIG_PATH) + + def _update_config(self): """ - Load the TOML config file as a dict when initializing the class. + Update the config file depending on initially loaded parameters. 
""" - config = load_toml_config(self.config_file) - # Handle specific parameters - # Transform ngram_range into tuple - if config["vectorizer_model"].get("ngram_range"): - config["vectorizer_model"]["ngram_range"] = tuple( - config["vectorizer_model"]["ngram_range"] + if self.config["vectorizer_model"].get("ngram_range"): + self.config["vectorizer_model"]["ngram_range"] = tuple( + self.config["vectorizer_model"]["ngram_range"] ) # Load stop words list - if config["vectorizer_model"].get("stop_words"): + if self.config["vectorizer_model"].get("stop_words"): stop_words = ( STOPWORDS - if config["global"]["language"] == "French" + if self.config["global"]["language"] == "French" else ENGLISH_STOPWORDS ) - config["vectorizer_model"]["stop_words"] = stop_words + self.config["vectorizer_model"]["stop_words"] = stop_words # BERTopic needs a "None" instead of an empty list, otherwise it'll attempt zeroshot topic modeling on an empty list - if not config["bertopic_model"].get("zeroshot_topic_list"): # empty list - config["bertopic_model"]["zeroshot_topic_list"] = None - - return config + if not self.config["bertopic_model"].get("zeroshot_topic_list"): # empty list + self.config["bertopic_model"]["zeroshot_topic_list"] = None def _initialize_models(self): self.umap_model = UMAP(**self.config["umap_model"]) diff --git a/bertrend/BERTrend.py b/bertrend/BERTrend.py index 32dc1db..42d4679 100644 --- a/bertrend/BERTrend.py +++ b/bertrend/BERTrend.py @@ -2,40 +2,50 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. +import copy +import os import pickle -import shutil + +import dill # improvement to pickle from collections import defaultdict from pathlib import Path -from typing import Dict, Tuple, List, Any +from typing import Any import numpy as np import pandas as pd from bertopic import BERTopic from loguru import logger +from pandas import Timestamp from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity +from tqdm import tqdm from bertrend import ( MODELS_DIR, CACHE_PATH, BERTREND_DEFAULT_CONFIG_PATH, load_toml_config, + SIGNAL_EVOLUTION_DATA_DIR, ) from bertrend.BERTopicModel import BERTopicModel from bertrend.config.parameters import ( DOC_INFO_DF_FILE, TOPIC_INFO_DF_FILE, - DOC_GROUPS_FILE, - MODELS_TRAINED_FILE, - EMB_GROUPS_FILE, - HYPERPARAMS_FILE, BERTOPIC_SERIALIZATION, + SIGNAL_CLASSIF_LOWER_BOUND, + SIGNAL_CLASSIF_UPPER_BOUND, + BERTREND_FILE, + LANGUAGES, ) +from bertrend.services.embedding_service import EmbeddingService from bertrend.trend_analysis.weak_signals import ( _initialize_new_topic, update_existing_topic, _apply_decay_to_inactive_topics, + _filter_data, + _is_rising_popularity, + _create_dataframes, ) from bertrend.utils.data_loading import TEXT_COLUMN @@ -78,11 +88,11 @@ def __init__( # Variables related to time-based topic models # - topic_models: Dictionary of trained BERTopic models for each timestamp. - self.topic_models: Dict[pd.Timestamp, BERTopic] = {} + self.topic_models: dict[pd.Timestamp, BERTopic] = {} # - doc_groups: Dictionary of document groups for each timestamp. - self.doc_groups: Dict[pd.Timestamp, List[str]] = {} + self.doc_groups: dict[pd.Timestamp, list[str]] = {} # - emb_groups: Dictionary of document embeddings for each timestamp. 
- self.emb_groups: Dict[pd.Timestamp, np.ndarray] = {} + self.emb_groups: dict[pd.Timestamp, np.ndarray] = {} # Variables containing info about merged topics self.all_new_topics_df = None @@ -91,13 +101,13 @@ def __init__( # Variables containing info about topic popularity # - topic_sizes: Dictionary storing topic sizes and related information over time. - self.topic_sizes: Dict[int, Dict[str, Any]] = defaultdict( + self.topic_sizes: dict[int, dict[str, Any]] = defaultdict( lambda: defaultdict(list) ) # - topic_last_popularity: Dictionary storing the last known popularity of each topic. - self.topic_last_popularity: Dict[int, float] = {} + self.topic_last_popularity: dict[int, float] = {} # - topic_last_update: Dictionary storing the last update timestamp of each topic. - self.topic_last_update: Dict[int, pd.Timestamp] = {} + self.topic_last_update: dict[int, pd.Timestamp] = {} def _load_config(self) -> dict: """ @@ -112,9 +122,9 @@ def _train_by_period( group: pd.DataFrame, embedding_model: SentenceTransformer, embeddings: np.ndarray, - ) -> Tuple[ + ) -> tuple[ BERTopic, - List[str], + list[str], np.ndarray, ]: """ @@ -196,7 +206,7 @@ def _train_by_period( def train_topic_models( self, - grouped_data: Dict[pd.Timestamp, pd.DataFrame], + grouped_data: dict[pd.Timestamp, pd.DataFrame], embedding_model: SentenceTransformer, embeddings: np.ndarray, ): @@ -250,12 +260,13 @@ def train_topic_models( self._is_fitted = True + # Merge the newly obtained topic models with new ones # Update topic_models: Dictionary of trained BERTopic models for each timestamp. - self.topic_models = topic_models + self.topic_models.update(topic_models) # Update doc_groups: Dictionary of document groups for each timestamp. - self.doc_groups = doc_groups + self.doc_groups.update(doc_groups) # Update emb_groups: Dictionary of document embeddings for each timestamp. - self.emb_groups = emb_groups + self.emb_groups.update(emb_groups) logger.success("Finished training all topic models") def merge_all_models( @@ -263,6 +274,14 @@ def merge_all_models( min_similarity: int | None = None, ): """Merge together all topic models.""" + logger.debug( + f"{len(self.topic_models)} topic models to be merged:\n{list(self.topic_models.keys())}" + ) + if len(self.topic_models) < 2: # beginning of the process, no real merge needed + logger.warning("This function requires at least two topic models. 
Ignored") + self._are_models_merged = False + return + # Get default BERTrend config if argument is not provided if min_similarity is None: min_similarity = self.config["min_similarity"] @@ -279,6 +298,9 @@ def merge_all_models( } timestamps = sorted(topic_dfs.keys()) + + assert len(self.topic_models) >= 2 + merged_df_without_outliers = None all_merge_histories = [] all_new_topics = [] @@ -286,8 +308,8 @@ def merge_all_models( # TODO: tqdm merge_df_size_over_time = [] - for i, (current_timestamp, next_timestamp) in enumerate( - zip(timestamps[:-1], timestamps[1:]) + for i, (current_timestamp, next_timestamp) in tqdm( + enumerate(zip(timestamps[:-1], timestamps[1:])) ): df1 = topic_dfs[current_timestamp][ topic_dfs[current_timestamp]["Topic"] != -1 @@ -303,7 +325,7 @@ def merge_all_models( ) = _merge_models( df1, df2, - min_similarity=min_similarity, # SessionStateManager.get("min_similarity"), + min_similarity=min_similarity, timestamp=current_timestamp, ) elif not df2.empty: @@ -314,7 +336,7 @@ def merge_all_models( ) = _merge_models( merged_df_without_outliers, df2, - min_similarity=min_similarity, # SessionStateManager.get("min_similarity"), + min_similarity=min_similarity, timestamp=current_timestamp, ) else: @@ -322,7 +344,7 @@ def merge_all_models( all_merge_histories.append(merge_history) all_new_topics.append(new_topics) - merge_df_size_over_time = merge_df_size_over_time # SessionStateManager.get("merge_df_size_over_time") + merge_df_size_over_time = merge_df_size_over_time merge_df_size_over_time.append( ( current_timestamp, @@ -371,9 +393,10 @@ def calculate_signal_popularity( # Check if models are merged if not self._are_models_merged: - raise RuntimeWarning( + logger.error( "You must merge topic models first before computing signal popularity." ) + return topic_sizes = defaultdict(lambda: defaultdict(list)) topic_last_popularity = {} @@ -447,9 +470,167 @@ def calculate_signal_popularity( self.topic_last_popularity = topic_last_popularity self.topic_last_update = topic_last_update + def _compute_popularity_values_and_thresholds( + self, window_size: int, current_date: Timestamp + ) -> tuple[Timestamp, Timestamp, list, float, float]: + """ + Computes the popularity values and thresholds for the considered time window. + + Args: + window_size (int): The retrospective window size in days. + current_date (datetime): The current date selected by the user. + + Returns: + Tuple[Timestamp,Timestamp, list, float, float,]: + window_start, window_end indicates the start / end periods. 
+ all_popularities_values + The q1 and q3 values representing the 10th and 90th percentiles of popularity values, + """ + + window_size_timedelta = pd.Timedelta(days=window_size) + granularity_timedelta = pd.Timedelta(days=self.config["granularity"]) + + current_date = pd.to_datetime(current_date).floor("D") # Floor to start of day + window_start = current_date - window_size_timedelta + window_end = current_date + granularity_timedelta + + # Calculate q1 and q3 values (we remove very low values of disappearing signals to not skew the thresholds) + all_popularity_values = [ + popularity + for topic, data in self.topic_sizes.items() + for timestamp, popularity in zip( + pd.to_datetime(data["Timestamps"]), data["Popularity"] + ) + if window_start <= timestamp <= current_date and popularity > 1e-5 + ] + + if all_popularity_values: + q1 = np.percentile(all_popularity_values, SIGNAL_CLASSIF_LOWER_BOUND) + q3 = np.percentile(all_popularity_values, SIGNAL_CLASSIF_UPPER_BOUND) + else: + q1, q3 = 0, 0 + + return window_start, window_end, all_popularity_values, q1, q3 + + def _classify_signals( + self, + window_start: pd.Timestamp, + window_end: pd.Timestamp, + q1: float, + q3: float, + rising_popularity_only: bool = True, + keep_documents: bool = True, + ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Classify signals into weak signal and strong signal dataframes. + + Args: + window_start (pd.Timestamp): The start timestamp of the window. + window_end (pd.Timestamp): The end timestamp of the window. + q1 (float): The 10th percentile of popularity values. + q3 (float): The 50th percentile of popularity values. + rising_popularity_only (bool): Whether to consider only rising popularity topics as weak signals. + keep_documents (bool): Whether to keep track of the documents or not. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + - noise_topics_df: DataFrame containing noise topics. + - weak_signal_topics_df: DataFrame containing weak signal topics. + - strong_signal_topics_df: DataFrame containing strong signal topics. 
+ """ + noise_topics = [] + weak_signal_topics = [] + strong_signal_topics = [] + + sorted_topics = sorted(self.topic_sizes.items(), key=lambda x: x[0]) + + for topic, data in sorted_topics: + filtered_data = _filter_data(data, window_end, keep_documents) + if not filtered_data["Timestamps"]: + continue + + window_popularities = [ + (timestamp, popularity) + for timestamp, popularity in zip( + filtered_data["Timestamps"], filtered_data["Popularity"] + ) + if window_start <= timestamp <= window_end + ] + + if window_popularities: + latest_timestamp, latest_popularity = window_popularities[-1] + docs_count = ( + filtered_data["Docs_Count"][-1] + if filtered_data["Docs_Count"] + else 0 + ) + paragraphs_count = ( + filtered_data["Paragraphs_Count"][-1] + if filtered_data["Paragraphs_Count"] + else 0 + ) + source_diversity = ( + filtered_data["Source_Diversity"][-1] + if filtered_data["Source_Diversity"] + else 0 + ) + + topic_data = ( + topic, + latest_popularity, + latest_timestamp, + docs_count, + paragraphs_count, + source_diversity, + filtered_data, + ) + + if latest_popularity < q1: + noise_topics.append(topic_data) + elif q1 <= latest_popularity <= q3: + if rising_popularity_only: + if _is_rising_popularity(filtered_data, latest_timestamp): + weak_signal_topics.append(topic_data) + else: + noise_topics.append(topic_data) + else: + weak_signal_topics.append(topic_data) + else: + strong_signal_topics.append(topic_data) + + return _create_dataframes( + noise_topics, weak_signal_topics, strong_signal_topics, keep_documents + ) + + def classify_signals( + self, window_size: int, current_date: Timestamp + ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Classify signals into weak signal and strong signal dataframes for the considered time window. + + Args: + window_size (int): The retrospective window size in days. + current_date (datetime): The current date selected by the user. + + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + - noise_topics_df: DataFrame containing noise topics. + - weak_signal_topics_df: DataFrame containing weak signal topics. + - strong_signal_topics_df: DataFrame containing strong signal topics. 
+ """ + # Compute threshold values + window_start, window_end, all_popularity_values, q1, q3 = ( + self._compute_popularity_values_and_thresholds(window_size, current_date) + ) + + # Classify signals + noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = ( + self._classify_signals(window_start, window_end, q1, q3) + ) + return noise_topics_df, weak_signal_topics_df, strong_signal_topics_df + def save_models(self, models_path: Path = MODELS_DIR): - if models_path.exists(): - shutil.rmtree(models_path) models_path.mkdir(parents=True, exist_ok=True) # Save topic models using the selected serialization type @@ -467,18 +648,12 @@ def save_models(self, models_path: Path = MODELS_DIR): topic_model.doc_info_df.to_pickle(model_dir / DOC_INFO_DF_FILE) topic_model.topic_info_df.to_pickle(model_dir / TOPIC_INFO_DF_FILE) - # Save topic model parameters - with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f: - pickle.dump(self.topic_model, f) - # Save doc_groups file - with open(CACHE_PATH / DOC_GROUPS_FILE, "wb") as f: - pickle.dump(self.doc_groups, f) - # Save emb_groups file - with open(CACHE_PATH / EMB_GROUPS_FILE, "wb") as f: - pickle.dump(self.emb_groups, f) - # Save the models_trained flag - with open(CACHE_PATH / MODELS_TRAINED_FILE, "wb") as f: - pickle.dump(self._is_fitted, f) + # Serialize BERTrend (excluding topic models for separate reuse if needed) + topic_models_bak = copy.deepcopy(self.topic_models) + self.topic_models = None + with open(models_path / BERTREND_FILE, "wb") as f: + dill.dump(self, f) + self.topic_models = topic_models_bak logger.info(f"Models saved to: {models_path}") @@ -489,25 +664,15 @@ def restore_models(cls, models_path: Path = MODELS_DIR) -> "BERTrend": logger.info(f"Loading models from: {models_path}") - # Create BERTrend object - bertrend = cls() - - # load topic model parameters - with open(CACHE_PATH / HYPERPARAMS_FILE, "rb") as f: - bertrend.topic_model = pickle.load(f) - # load doc_groups file - with open(CACHE_PATH / DOC_GROUPS_FILE, "rb") as f: - bertrend.doc_groups = pickle.load(f) - # load emb_groups file - with open(CACHE_PATH / EMB_GROUPS_FILE, "rb") as f: - bertrend.emb_groups = pickle.load(f) - # load the models_trained flag - with open(CACHE_PATH / MODELS_TRAINED_FILE, "rb") as f: - bertrend._is_fitted = pickle.load(f) + # Unserialize BERTrend object (using dill as an improvement of pickle for complex objects) + with open(models_path / BERTREND_FILE, "rb") as f: + bertrend = dill.load(f) # Restore topic models using the selected serialization type topic_models = {} - for period_dir in models_path.iterdir(): + for period_dir in models_path.glob( + r"????-??-??" 
+ ): # filter dir that are formatted YYYY-MM-DD if period_dir.is_dir(): topic_model = BERTopic.load(period_dir) @@ -527,9 +692,123 @@ def restore_models(cls, models_path: Path = MODELS_DIR) -> "BERTrend": return bertrend + def save_signal_evolution_data( + self, + window_size: int, + start_timestamp: pd.Timestamp, + end_timestamp: pd.Timestamp, + ) -> Path: + save_path = SIGNAL_EVOLUTION_DATA_DIR / f"retrospective_{window_size}_days" + os.makedirs(save_path, exist_ok=True) + + q1_values, q3_values, timestamps_over_time = [], [], [] + noise_dfs, weak_signal_dfs, strong_signal_dfs = [], [], [] + + for current_timestamp in tqdm( + pd.date_range( + start=start_timestamp, + end=end_timestamp, + freq=pd.Timedelta(days=self.config["granularity"]), + ), + desc="Processing timestamps", + ): + window_start, window_end, all_popularity_values, q1, q3 = ( + self._compute_popularity_values_and_thresholds( + window_size, current_timestamp + ) + ) + + noise_df, weak_signal_df, strong_signal_df = self._classify_signals( + window_start, window_end, q1, q3, keep_documents=False + ) + + noise_dfs.append(noise_df) + weak_signal_dfs.append(weak_signal_df) + strong_signal_dfs.append(strong_signal_df) + + timestamps_over_time.append(current_timestamp) + + # Save the grouped dataframes + with open(save_path / "noise_dfs_over_time.pkl", "wb") as f: + pickle.dump(noise_dfs, f) + with open(save_path / "weak_signal_dfs_over_time.pkl", "wb") as f: + pickle.dump(weak_signal_dfs, f) + with open(save_path / "strong_signal_dfs_over_time.pkl", "wb") as f: + pickle.dump(strong_signal_dfs, f) + + # Save the metadata + with open(save_path / "metadata.pkl", "wb") as f: + metadata = { + "window_size": window_size, + "granularity": self.config["granularity"], + "timestamps": timestamps_over_time, + "q1_values": q1_values, + "q3_values": q3_values, + } + pickle.dump(metadata, f) + + return save_path + + +def train_new_data( + new_data: pd.DataFrame, + bertrend_models_path: Path, + embedding_service: EmbeddingService, + granularity: int, + language: str, +) -> BERTrend: + """Helper function for processing new data (incremental trend analysis: + - loads a previous saved BERTrend model + - train a new topic model with the new data + - merge the models and update merge histories + - save the model and returns it + """ + logger.debug(f"Processing new data: {len(new_data)} items") + + # timestamp used to reference the model + reference_timestamp = pd.Timestamp(new_data["timestamp"].max().date()) + logger.info(f"Reference timestamp: {reference_timestamp}") + + # Restore previous models + try: + logger.info(f"Restoring previous BERTrend models from {bertrend_models_path}") + bertrend = BERTrend.restore_models(bertrend_models_path) + except: + logger.warning("Cannot restore previous models, creating new one") + # overrides default params + if language and language in LANGUAGES: + bertrend = BERTrend( + topic_model=BERTopicModel({"global": {"language": language}}) + ) + else: + bertrend = BERTrend(topic_model=BERTopicModel()) + bertrend.config["granularity"] = granularity + + # Embed new data + embeddings, token_strings, token_embeddings = embedding_service.embed( + texts=new_data[TEXT_COLUMN] + ) + embedding_model_name = embedding_service.embedding_model_name + + # Create topic model for new data + bertrend.train_topic_models( + {reference_timestamp: new_data}, + embeddings=embeddings, + embedding_model=embedding_model_name, + ) + + logger.info(f"BERTrend contains {len(bertrend.topic_models)} topic models") + # Save models + 
bertrend.save_models(models_path=bertrend_models_path) + + # Merge models + bertrend.merge_all_models() + + return bertrend + def _preprocess_model( - topic_model: BERTopic, docs: List[str], embeddings: np.ndarray + topic_model: BERTopic, docs: list[str], embeddings: np.ndarray ) -> pd.DataFrame: """ Preprocess a BERTopic model by extracting topic information, document groups, document embeddings, and URLs. @@ -581,7 +860,7 @@ def _preprocess_model( def _merge_models( df1: pd.DataFrame, df2: pd.DataFrame, min_similarity: float, timestamp: pd.Timestamp -) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: +) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: merged_df = df1.copy() merge_history = [] diff --git a/bertrend/__init__.py b/bertrend/__init__.py index df427dd..6647762 100644 --- a/bertrend/__init__.py +++ b/bertrend/__init__.py @@ -40,18 +40,13 @@ ) # Base dirs -BASE_DATA_PATH = BASE_PATH / "data" -BASE_CACHE_PATH = BASE_PATH / "cache" -BASE_OUTPUT_PATH = BASE_PATH / "output" +DATA_PATH = BASE_PATH / "data" +CACHE_PATH = BASE_PATH / "cache" +OUTPUT_PATH = BASE_PATH / "output" +CONFIG_PATH = BASE_PATH / "config" -FEED_BASE_PATH = BASE_DATA_PATH / "bertrend" / "feeds" -BERTREND_LOG_PATH = BASE_PATH / "logs" / "bertrend" -BERTREND_LOG_PATH.mkdir(parents=True, exist_ok=True) - -# Define directories -DATA_PATH = BASE_DATA_PATH / "bertrend" -OUTPUT_PATH = BASE_OUTPUT_PATH / "bertrend" -CACHE_PATH = BASE_CACHE_PATH / "bertrend" +FEED_BASE_PATH = DATA_PATH / "feeds" +BERTREND_LOG_PATH = BASE_PATH / "logs" # Weak signals MODELS_DIR = CACHE_PATH / "models" @@ -62,3 +57,6 @@ DATA_PATH.mkdir(parents=True, exist_ok=True) OUTPUT_PATH.mkdir(parents=True, exist_ok=True) CACHE_PATH.mkdir(parents=True, exist_ok=True) +CONFIG_PATH.mkdir(parents=True, exist_ok=True) +BERTREND_LOG_PATH.mkdir(parents=True, exist_ok=True) +MODELS_DIR.mkdir(parents=True, exist_ok=True) diff --git a/bertrend/config/parameters.py b/bertrend/config/parameters.py index 76f2243..03d3d4f 100644 --- a/bertrend/config/parameters.py +++ b/bertrend/config/parameters.py @@ -28,13 +28,9 @@ # File names STATE_FILE = "app_state.pkl" EMBEDDINGS_FILE = "embeddings.npy" -DOC_GROUPS_FILE = "doc_groups.pkl" -EMB_GROUPS_FILE = "emb_groups.pkl" -GRANULARITY_FILE = "granularity.pkl" -HYPERPARAMS_FILE = "hyperparams.pkl" +BERTREND_FILE = "bertrend.dill" DOC_INFO_DF_FILE = "doc_info_df.pkl" TOPIC_INFO_DF_FILE = "topic_info_df.pkl" -MODELS_TRAINED_FILE = "models_trained_flag.pkl" # Model file names ZEROSHOT_TOPICS_DATA_FILE = "zeroshot_topics_data.json" diff --git a/bertrend/demos/demos_utils/data_loading_component.py b/bertrend/demos/demos_utils/data_loading_component.py index 7dc5932..1801408 100644 --- a/bertrend/demos/demos_utils/data_loading_component.py +++ b/bertrend/demos/demos_utils/data_loading_component.py @@ -5,7 +5,6 @@ from pathlib import Path from tempfile import TemporaryDirectory -from typing import List import pandas as pd import streamlit as st @@ -49,8 +48,8 @@ def _process_uploaded_files( - files: List[UploadedFile], -) -> List[pd.DataFrame]: + files: list[UploadedFile], +) -> list[pd.DataFrame]: """Read a list of uploaded files and return a list of dataframes containing the associated data""" dataframes = [] with TemporaryDirectory() as tmpdir: @@ -68,8 +67,8 @@ def _process_uploaded_files( def _load_files( - files: List[Path], -) -> List[pd.DataFrame]: + files: list[Path], +) -> list[pd.DataFrame]: """Read a list of files from storage and return a list of dataframes containing the associated data""" dfs = [] for 
selected_file in files: diff --git a/bertrend/demos/demos_utils/icons.py b/bertrend/demos/demos_utils/icons.py index 3a19b06..cd56790 100644 --- a/bertrend/demos/demos_utils/icons.py +++ b/bertrend/demos/demos_utils/icons.py @@ -6,10 +6,15 @@ WARNING_ICON = ":material/warning:" ERROR_ICON = ":material/error:" INFO_ICON = ":material/info:" +EDIT_ICON = ":material/edit:" +ADD_ICON = ":material/add_circle:" +DELETE_ICON = ":material/delete:" SUCCESS_ICON = ":material/check:" SETTINGS_ICON = ":material/settings:" TOPIC_ICON = ":material/speaker_notes:" TREND_ICON = ":material/trending_up:" +MODELS_ICON = ":material/network_intelligence:" +EMBEDDING_ICON = ":material/memory:" SAVE_ICON = ":material/save:" TOPIC_EXPLORATION_ICON = ":material/explore:" TOPIC_VISUALIZATION_ICON = ":material/monitoring:" @@ -20,6 +25,12 @@ MODEL_TRAINING_ICON = ":material/cognition:" SERVER_STORAGE_ICON = ":material/database:" CLIENT_STORAGE_ICON = ":material/upload:" +UNHAPPY_ICON = ":material/sentiment_extremely_dissatisfied:" +TOGGLE_ON_ICON = ":material/toggle_on:" +TOGGLE_OFF_ICON = ":material/toggle_off:" +NOISE_ICON = ":material/signal_cellular_off:" +WEAK_SIGNAL_ICON = ":material/signal_cellular_1_bar:" +STRONG_SIGNAL_ICON = ":material/signal_cellular_3_bar:" JSON_ICON = "🧾" PARQUET_ICON = "📦️" diff --git a/bertrend/demos/demos_utils/parameters_component.py b/bertrend/demos/demos_utils/parameters_component.py index 020677c..d2fd466 100644 --- a/bertrend/demos/demos_utils/parameters_component.py +++ b/bertrend/demos/demos_utils/parameters_component.py @@ -85,7 +85,8 @@ def display_remote_embeddings(): ) -def display_bertopic_hyperparameters(): +def display_embedding_hyperparameters(): + """UI settings for embedding hyperparameters""" # Embedding model parameters with st.expander("Embedding Model Settings", expanded=False): register_widget("embedding_service_type") @@ -103,6 +104,8 @@ def display_bertopic_hyperparameters(): else: display_remote_embeddings() + +def display_bertopic_hyperparameters(): # BERTopic model parameters with st.expander("BERTopic Model Settings", expanded=False): # If BERTopic config is already in session state, use it diff --git a/bertrend/demos/demos_utils/state_utils.py b/bertrend/demos/demos_utils/state_utils.py index 37cc04c..e9667c9 100644 --- a/bertrend/demos/demos_utils/state_utils.py +++ b/bertrend/demos/demos_utils/state_utils.py @@ -2,7 +2,7 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
-from typing import Any, Dict, Optional +from typing import Any import numpy as np import pandas as pd @@ -77,7 +77,7 @@ def get_or_set(key: str, default: Any) -> Any: return st.session_state[key] @staticmethod - def get_multiple(*keys: str) -> Dict[str, Any]: + def get_multiple(*keys: str) -> dict[str, Any]: return {key: st.session_state.get(key) for key in keys} @staticmethod @@ -90,10 +90,10 @@ def clear() -> None: st.session_state.clear() @staticmethod - def get_dataframe(key: str) -> Optional[pd.DataFrame]: + def get_dataframe(key: str) -> pd.DataFrame | None: df = st.session_state.get(key) return df if isinstance(df, pd.DataFrame) else None @staticmethod - def get_embeddings(key: str = "embeddings") -> Optional[np.ndarray]: + def get_embeddings(key: str = "embeddings") -> np.ndarray | None: return st.session_state.get(key) diff --git a/bertrend/demos/topic_analysis/app_utils.py b/bertrend/demos/topic_analysis/app_utils.py index 36b18c4..bc88b23 100644 --- a/bertrend/demos/topic_analysis/app_utils.py +++ b/bertrend/demos/topic_analysis/app_utils.py @@ -2,8 +2,6 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. -from typing import List - import numpy as np import pandas as pd import streamlit as st @@ -47,7 +45,7 @@ def compute_topics_over_time( def print_docs_for_specific_topic( - df: pd.DataFrame, topics: List[int], topic_number: int + df: pd.DataFrame, topics: list[int], topic_number: int ): """ Print documents for a specific topic diff --git a/bertrend/demos/topic_analysis/demo_pages/explore_topics.py b/bertrend/demos/topic_analysis/demo_pages/explore_topics.py index e5c4cb0..0e44ece 100644 --- a/bertrend/demos/topic_analysis/demo_pages/explore_topics.py +++ b/bertrend/demos/topic_analysis/demo_pages/explore_topics.py @@ -6,7 +6,6 @@ import io import re import zipfile -from typing import List import pandas as pd import streamlit as st @@ -174,7 +173,7 @@ def get_representative_documents(top_n_docs: int): def display_source_distribution( - representative_df: pd.DataFrame, selected_sources: List[str] + representative_df: pd.DataFrame, selected_sources: list[str] ): """Display the distribution of sources in a pie chart.""" @@ -305,7 +304,9 @@ def _display_topic_description(filtered_df: pd.DataFrame): language_code=language_code, ) with st.container(border=True): - st.markdown(gpt_description) + st.markdown( + f"### {gpt_description['title']}\n{gpt_description['description']}" + ) def main(): diff --git a/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py b/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py index 88fcb49..f2c9b1f 100644 --- a/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py +++ b/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
import inspect -from typing import Tuple import pandas as pd import streamlit as st @@ -37,7 +36,7 @@ def generate_newsletter_wrapper( df: pd.DataFrame, df_split: pd.DataFrame -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: """Wrapper function to generate newsletter based on user settings.""" top_n_topics = ( None diff --git a/bertrend/demos/topic_analysis/demo_pages/training_page.py b/bertrend/demos/topic_analysis/demo_pages/training_page.py index 58b7ac3..6eee769 100644 --- a/bertrend/demos/topic_analysis/demo_pages/training_page.py +++ b/bertrend/demos/topic_analysis/demo_pages/training_page.py @@ -22,6 +22,7 @@ ERROR_ICON, SETTINGS_ICON, INFO_ICON, + EMBEDDING_ICON, ) from bertrend.demos.demos_utils.messages import ( NO_EMBEDDINGS_WARNING_MESSAGE, @@ -38,6 +39,7 @@ ) from bertrend.demos.demos_utils.parameters_component import ( display_bertopic_hyperparameters, + display_embedding_hyperparameters, ) from bertrend.demos.weak_signals.visualizations_utils import PLOTLY_BUTTON_SAVE_CONFIG from bertrend.metrics.topic_metrics import compute_cluster_metrics @@ -175,6 +177,8 @@ def main(): # In the sidebar form with st.sidebar: st.header(SETTINGS_ICON + " Settings") + st.subheader(EMBEDDING_ICON + " Embedding Hyperparameters") + display_embedding_hyperparameters() display_bertopic_hyperparameters() # Load data diff --git a/bertrend/demos/weak_signals/app.py b/bertrend/demos/weak_signals/app.py index e860c21..68baa14 100644 --- a/bertrend/demos/weak_signals/app.py +++ b/bertrend/demos/weak_signals/app.py @@ -32,6 +32,7 @@ ANALYSIS_ICON, MODEL_TRAINING_ICON, DATA_LOADING_ICON, + EMBEDDING_ICON, ) from bertrend.demos.demos_utils.messages import ( NO_EMBEDDINGS_WARNING_MESSAGE, @@ -40,6 +41,7 @@ from bertrend.demos.demos_utils.parameters_component import ( display_bertopic_hyperparameters, display_bertrend_hyperparameters, + display_embedding_hyperparameters, ) from bertrend.BERTopicModel import BERTopicModel from bertrend.demos.weak_signals.messages import ( @@ -479,6 +481,8 @@ def main(): SessionStateManager.clear() # BERTopic Hyperparameters + st.subheader(EMBEDDING_ICON + " Embedding Hyperparameters") + display_embedding_hyperparameters() st.subheader(TOPIC_ICON + " BERTopic Hyperparameters") display_bertopic_hyperparameters() st.subheader(TREND_ICON + " BERTrend Hyperparameters") diff --git a/bertrend/demos/weak_signals/visualizations_utils.py b/bertrend/demos/weak_signals/visualizations_utils.py index 82428c9..ae540c7 100644 --- a/bertrend/demos/weak_signals/visualizations_utils.py +++ b/bertrend/demos/weak_signals/visualizations_utils.py @@ -2,7 +2,6 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
-from typing import Dict import pandas as pd import streamlit as st @@ -10,9 +9,14 @@ from pandas import Timestamp from plotly import graph_objects as go -from bertrend import OUTPUT_PATH, SIGNAL_EVOLUTION_DATA_DIR -from bertrend.demos.demos_utils.icons import WARNING_ICON, SUCCESS_ICON, INFO_ICON -from bertrend.demos.weak_signals.messages import HTML_GENERATION_FAILED_WARNING +from bertrend import SIGNAL_EVOLUTION_DATA_DIR +from bertrend.demos.demos_utils.icons import ( + SUCCESS_ICON, + INFO_ICON, + STRONG_SIGNAL_ICON, + WEAK_SIGNAL_ICON, + NOISE_ICON, +) from bertrend.demos.demos_utils.state_utils import SessionStateManager from bertrend.config.parameters import ( MAX_WINDOW_SIZE, @@ -24,13 +28,10 @@ create_sankey_diagram_plotly, plot_newly_emerged_topics, plot_topics_for_model, - compute_popularity_values_and_thresholds, create_topic_size_evolution_figure, plot_topic_size_evolution, ) from bertrend.trend_analysis.weak_signals import ( - classify_signals, - save_signal_evolution_data, analyze_signal, ) @@ -81,56 +82,79 @@ def display_signal_categories_df( weak_signal_topics_df: pd.DataFrame, strong_signal_topics_df: pd.DataFrame, window_end: Timestamp, + columns=None, + column_order=None, ): """Display the dataframes associated to each signal category: noise, weak signal, strong signal.""" - columns = [ - "Topic", - "Sources", - "Source_Diversity", - "Representation", - "Latest_Popularity", - "Docs_Count", - "Paragraphs_Count", - "Latest_Timestamp", - "Documents", - ] + if columns is None: + columns = [ + "Topic", + "Sources", + "Source_Diversity", + "Representation", + "Latest_Popularity", + "Docs_Count", + "Paragraphs_Count", + "Latest_Timestamp", + "Documents", + ] + if column_order is None: + column_order = columns + + with st.expander(f":orange[{WEAK_SIGNAL_ICON} Weak Signals]", expanded=True): + st.subheader(":orange[Weak Signals]") + if not weak_signal_topics_df.empty: + displayed_df = weak_signal_topics_df[columns].sort_values( + by=["Latest_Popularity"], ascending=False + ) + displayed_df["Documents"] = displayed_df["Documents"].astype(str) + st.dataframe( + displayed_df, + # hide_index=True, + column_order=column_order, + ) - st.subheader(":grey[Noise]") - if not noise_topics_df.empty: - st.dataframe( - noise_topics_df.astype(str)[columns].sort_values( - by=["Topic", "Latest_Popularity"], ascending=[False, False] + else: + st.info( + f"No weak signals were detected at timestamp {window_end}.", + icon=INFO_ICON, ) - ) - else: - st.info( - f"No noisy signals were detected at timestamp {window_end}.", icon=INFO_ICON - ) - st.subheader(":orange[Weak Signals]") - if not weak_signal_topics_df.empty: - st.dataframe( - weak_signal_topics_df.astype(str)[columns].sort_values( - by=["Latest_Popularity"], ascending=True + with st.expander(f":green[{STRONG_SIGNAL_ICON} Strong Signals]", expanded=True): + st.subheader(":green[Strong Signals]") + if not strong_signal_topics_df.empty: + displayed_df = strong_signal_topics_df[columns].sort_values( + by=["Latest_Popularity"], ascending=False + ) + displayed_df["Documents"] = displayed_df["Documents"].astype(str) + st.dataframe( + displayed_df, + hide_index=True, + column_order=column_order, + ) + else: + st.info( + f"No strong signals were detected at timestamp {window_end}.", + icon=INFO_ICON, ) - ) - else: - st.info( - f"No weak signals were detected at timestamp {window_end}.", icon=INFO_ICON - ) - st.subheader(":green[Strong Signals]") - if not strong_signal_topics_df.empty: - st.dataframe( - 
strong_signal_topics_df.astype(str)[columns].sort_values( - by=["Topic", "Latest_Popularity"], ascending=[False, False] + with st.expander(f":grey[{NOISE_ICON} Noise]", expanded=True): + st.subheader(":grey[Noise]") + if not noise_topics_df.empty: + displayed_df = noise_topics_df[columns].sort_values( + by=["Latest_Popularity"], ascending=False + ) + displayed_df["Documents"] = displayed_df["Documents"].astype(str) + st.dataframe( + displayed_df, + hide_index=True, + column_order=column_order, + ) + else: + st.info( + f"No noisy signals were detected at timestamp {window_end}.", + icon=INFO_ICON, ) - ) - else: - st.info( - f"No strong signals were detected at timestamp {window_end}.", - icon=INFO_ICON, - ) def display_popularity_evolution(): @@ -166,16 +190,14 @@ def display_popularity_evolution(): key="current_date", ) - # Compute threshold values + # Compute threshold values and classify signals window_start, window_end, all_popularity_values, q1, q3 = ( - compute_popularity_values_and_thresholds( - bertrend.topic_sizes, window_size, granularity, current_date - ) + bertrend._compute_popularity_values_and_thresholds(window_size, current_date) ) # Classify signals - noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = classify_signals( - bertrend.topic_sizes, window_start, window_end, q1, q3 + noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = ( + bertrend._classify_signals(window_start, window_end, q1, q3) ) # Display threshold values for noise and strong signals @@ -225,13 +247,8 @@ def save_signal_evolution(): if st.button("Save Signal Evolution Data"): try: - save_path = save_signal_evolution_data( - all_merge_histories_df=all_merge_histories_df, - topic_sizes=dict(bertrend.topic_sizes), - topic_last_popularity=bertrend.topic_last_popularity, - topic_last_update=bertrend.topic_last_update, + save_path = bertrend.save_signal_evolution_data( window_size=SessionStateManager.get("window_size"), - granularity=granularity, start_timestamp=pd.Timestamp(start_date), end_timestamp=pd.Timestamp(end_date), ) @@ -270,7 +287,7 @@ def display_newly_emerged_topics(all_new_topics_df: pd.DataFrame) -> None: ) -def display_topics_per_timestamp(topic_models: Dict[pd.Timestamp, BERTopic]) -> None: +def display_topics_per_timestamp(topic_models: dict[pd.Timestamp, BERTopic]) -> None: """ Plot the topics discussed per source for each timestamp. 
@@ -298,44 +315,23 @@ def display_topics_per_timestamp(topic_models: Dict[pd.Timestamp, BERTopic]) -> st.dataframe(selected_model.topic_info_df, use_container_width=True) -def display_signal_analysis( - topic_number: int, output_file_name: str = "signal_llm.html" -): +def display_signal_analysis(topic_number: int): """Display a LLM-based analyis of a specific topic.""" - language = SessionStateManager.get("language") bertrend = SessionStateManager.get("bertrend") - granularity = SessionStateManager.get("granularity") - all_merge_histories_df = bertrend.all_merge_histories_df st.subheader("Signal Interpretation") with st.spinner("Analyzing signal..."): summary, analysis, formatted_html = analyze_signal( + bertrend, topic_number, SessionStateManager.get("current_date"), - all_merge_histories_df, - granularity, - language, ) - # Check if the HTML file was created successfully - output_file_path = OUTPUT_PATH / output_file_name - if output_file_path.exists(): - # Read the HTML file - with open(output_file_path, "r", encoding="utf-8") as file: - html_content = file.read() - # Display the HTML content - st.html(html_content) - else: - st.warning(HTML_GENERATION_FAILED_WARNING, icon=WARNING_ICON) - # Fallback to displaying markdown if HTML generation fails - col1, col2 = st.columns(spec=[0.5, 0.5], gap="medium") - with col1: - st.markdown(summary) - with col2: - st.markdown(analysis) + # Display the HTML content + st.html(formatted_html) -def retrieve_topic_counts(topic_models: Dict[pd.Timestamp, BERTopic]) -> None: +def retrieve_topic_counts(topic_models: dict[pd.Timestamp, BERTopic]) -> None: individual_model_topic_counts = [ (timestamp, model.topic_info_df["Topic"].max() + 1) for timestamp, model in topic_models.items() diff --git a/bertrend/llm_utils/newsletter_features.py b/bertrend/llm_utils/newsletter_features.py index 2f87a65..0374408 100644 --- a/bertrend/llm_utils/newsletter_features.py +++ b/bertrend/llm_utils/newsletter_features.py @@ -6,7 +6,6 @@ from pathlib import Path import os import locale -from typing import List, Tuple, Any # from md2pdf.core import md2pdf import markdown @@ -23,7 +22,6 @@ USER_GENERATE_TOPIC_LABEL_SUMMARIES, ) from bertrend.services.summarizer import Summarizer -from bertrend.services.summary.abstractive_summarizer import AbstractiveSummarizer from bertopic._bertopic import BERTopic from tqdm import tqdm @@ -39,7 +37,7 @@ def generate_newsletter( topic_model: BERTopic, df: pd.DataFrame, - topics: List[int], + topics: list[int], df_split: pd.DataFrame = None, top_n_topics: int = DEFAULT_TOP_N_TOPICS, top_n_docs: int = DEFAULT_TOP_N_DOCS, @@ -51,7 +49,7 @@ def generate_newsletter( improve_topic_description: bool = False, openai_model_name: str = None, nb_sentences: int = 3, -) -> Tuple[str, str, str]: +) -> tuple[str, str, str]: """Generates a newsletters based on a trained BERTopic model. Args: diff --git a/bertrend/llm_utils/openai_client.py b/bertrend/llm_utils/openai_client.py index a4140a7..ccaf7d2 100644 --- a/bertrend/llm_utils/openai_client.py +++ b/bertrend/llm_utils/openai_client.py @@ -4,7 +4,6 @@ # This file is part of BERTrend. import os -from typing import List, Dict from openai import OpenAI, AzureOpenAI, Timeout, Stream from loguru import logger @@ -100,7 +99,7 @@ def generate( def generate_from_history( self, - messages: List[Dict], + messages: list[dict], **kwargs, ) -> ChatCompletion | Stream[ChatCompletionChunk] | str: """Call openai model for generation. 
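Note (editor): the demo changes above rely on the new BERTrend persistence and signal-classification API introduced earlier in this patch (a single dill-serialized bertrend file next to per-period BERTopic directories, and classification methods living on the instance). A minimal sketch of that flow outside Streamlit, assuming a models directory previously produced by save_models() and that calculate_signal_popularity() was run before saving; the path and date below are illustrative:

from pathlib import Path

import pandas as pd

from bertrend.BERTrend import BERTrend

# Illustrative location previously passed to save_models().
models_path = Path("./models")

# restore_models() now unpickles the dill-serialized BERTrend object and
# re-attaches the per-period BERTopic models stored in YYYY-MM-DD folders.
bertrend = BERTrend.restore_models(models_path)

# classify_signals() wraps the two underscored helpers used by the demo,
# _compute_popularity_values_and_thresholds() and _classify_signals().
noise_df, weak_df, strong_df = bertrend.classify_signals(
    window_size=14,                           # retrospective window in days
    current_date=pd.Timestamp("2024-01-01"),  # illustrative date
)
print(len(noise_df), len(weak_df), len(strong_df))
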
diff --git a/bertrend/metrics/temporal_metrics_embedding.py b/bertrend/metrics/temporal_metrics_embedding.py index 3774bc7..1117bb7 100644 --- a/bertrend/metrics/temporal_metrics_embedding.py +++ b/bertrend/metrics/temporal_metrics_embedding.py @@ -56,7 +56,6 @@ import itertools from scipy.sparse import csr_matrix from sklearn.preprocessing import normalize -from typing import List, Union, Tuple import re from bertrend import OUTPUT_PATH @@ -66,12 +65,12 @@ class TempTopic: def __init__( self, topic_model: BERTopic, - docs: List[str], - embeddings: List[List[float]], - word_embeddings: List[List[List[float]]], - token_strings: List[List[str]], - timestamps: Union[List[str], List[int]], - topics: List[int] = None, + docs: list[str], + embeddings: list[list[float]], + word_embeddings: list[list[list[float]]], + token_strings: list[list[str]], + timestamps: list[str] | list[int], + topics: list[int] = None, evolution_tuning: bool = True, global_tuning: bool = False, ): @@ -355,12 +354,12 @@ def _topics_over_time(self) -> pd.DataFrame: def _fuzzy_match_and_embed( self, phrase: str, - token_strings: List[List[str]], - token_embeddings: List[np.ndarray], + token_strings: list[list[str]], + token_embeddings: list[np.ndarray], topic_id: int, timestamp: str, window_size: int, - ) -> Tuple[str, np.ndarray]: + ) -> tuple[str, np.ndarray]: """ Matches a phrase to the most similar window in token_strings using fuzzy matching and returns the corresponding embedding. @@ -405,7 +404,7 @@ def _fuzzy_match_and_embed( def _log_failed_match( self, phrase: str, - token_strings: List[List[str]], + token_strings: list[list[str]], topic_id: int, timestamp: str, best_match: str, @@ -535,7 +534,7 @@ def _calculate_representation_embeddings( def calculate_temporal_representation_stability( self, window_size: int = 2, k: int = 1 - ) -> Tuple[pd.DataFrame, float]: + ) -> tuple[pd.DataFrame, float]: """ Calculates the Temporal Representation Stability (TRS) scores for each topic. @@ -637,7 +636,7 @@ def calculate_temporal_representation_stability( def calculate_topic_embedding_stability( self, window_size: int = 2 - ) -> Tuple[pd.DataFrame, float]: + ) -> tuple[pd.DataFrame, float]: """ Calculates the Temporal Topic Embedding Stability (TTES) scores for each topic. @@ -764,7 +763,7 @@ def calculate_overall_topic_stability( def find_similar_topic_pairs( self, similarity_threshold: float = 0.8 - ) -> List[List[Tuple[int, int, str]]]: + ) -> list[list[tuple[int, int, str]]]: """ Finds similar topic pairs based on cosine similarity. diff --git a/bertrend/metrics/topic_metrics.py b/bertrend/metrics/topic_metrics.py index 74fa33c..f7fdac8 100644 --- a/bertrend/metrics/topic_metrics.py +++ b/bertrend/metrics/topic_metrics.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
-from typing import List from itertools import combinations import pandas as pd @@ -15,8 +14,8 @@ def get_coherence_value( topic_model: BERTopic, - topics: List[int], - docs: List[str], + topics: list[int], + docs: list[str], coherence_score_type: str = "c_npmi", ) -> float: """ @@ -72,8 +71,8 @@ def get_coherence_value( def get_diversity_value( topic_model: BERTopic, - topics: List[List[str]], - docs: List[str], + topics: list[list[str]], + docs: list[str], diversity_score_type: str = "puw", topk: int = 5, ) -> float: @@ -160,7 +159,7 @@ def compute_cluster_metrics(bertopic: BERTopic, topics: list[int], dataset: list ) -def proportion_unique_words(topics: List[List[str]], top_k: int) -> float: +def proportion_unique_words(topics: list[list[str]], top_k: int) -> float: """ Compute the proportion of unique words. @@ -178,7 +177,7 @@ def proportion_unique_words(topics: List[List[str]], top_k: int) -> float: return puw -def pairwise_jaccard_diversity(topics: List[List[str]], top_k: int) -> float: +def pairwise_jaccard_diversity(topics: list[list[str]], top_k: int) -> float: """ Compute the average pairwise Jaccard distance between the topics. diff --git a/bertrend/resources/topic_model_config_example.toml b/bertrend/resources/topic_model_config_example.toml deleted file mode 100644 index 69acc3c..0000000 --- a/bertrend/resources/topic_model_config_example.toml +++ /dev/null @@ -1,20 +0,0 @@ -# Example of configuration file to be used for topic modelling - -# BERTopic Hyperparameters -[bertopic_parameters] -umap_n_components = 5 -umap_n_neighbors = 5 -umap_min_dist = 0.0 -hdbscan_min_cluster_size = 5 -hdbscan_min_samples = 5 -top_n_words = 8 -min_df = 1 -granularity = 2 -min_similarity = 0.7 -zeroshot_min_similarity = 0.5 -bertopic_serialization = "safetensors" # or pickle -mmr_diversity = 0.3 -outlier_reduction_strategy = "c-tf-idf" # or "embeddings" -zeroshot_topics = [] # default list of topics -language="English" # or French -representation_models=["MaximalMarginalRelevance"] # and / or "KeyBERTInspired", "OpenAI" diff --git a/bertrend/services/embedding_service.py b/bertrend/services/embedding_service.py index 2aca204..bd9e7e2 100644 --- a/bertrend/services/embedding_service.py +++ b/bertrend/services/embedding_service.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. import json -from typing import List, Tuple, Union, Literal +from typing import Literal import numpy as np import pandas as pd @@ -52,10 +52,10 @@ def __init__( self.embedding_model_name = model_name self.embedding_dtype = embedding_dtype - def embed(self, texts: Union[List[str], pd.Series], verbose: bool = False) -> Tuple[ + def embed(self, texts: list[str] | pd.Series, verbose: bool = False) -> tuple[ np.ndarray, - List[List[str]] | None, - List[np.ndarray] | None, + list[list[str]] | None, + list[np.ndarray] | None, ]: """ Embed a list of documents using a Sentence Transformer model. @@ -84,11 +84,11 @@ def embed(self, texts: Union[List[str], pd.Series], verbose: bool = False) -> Tu def _local_embed_documents( self, - texts: List[str], + texts: list[str], embedding_device: str = EMBEDDING_DEVICE, batch_size: int = EMBEDDING_BATCH_SIZE, max_seq_length: int = EMBEDDING_MAX_SEQ_LENGTH, - ) -> Tuple[np.ndarray, List[List[str]], List[np.ndarray]]: + ) -> tuple[np.ndarray, list[list[str]], list[np.ndarray]]: """ Embed a list of documents using a Sentence Transformer model. 
@@ -181,8 +181,8 @@ def _local_embed_documents( return embeddings, token_strings, token_embeddings def _remote_embed_documents( - self, texts: List[str], show_progress_bar: bool = True - ) -> Tuple[np.ndarray, None, None]: + self, texts: list[str], show_progress_bar: bool = True + ) -> tuple[np.ndarray, None, None]: """ Embed a list of documents using a Sentence Transformer model. diff --git a/bertrend/services/summarizer.py b/bertrend/services/summarizer.py index 3c3a17c..ed4866e 100644 --- a/bertrend/services/summarizer.py +++ b/bertrend/services/summarizer.py @@ -4,7 +4,6 @@ # This file is part of BERTrend. from abc import ABC, abstractmethod -from typing import List DEFAULT_SUMMARIZATION_RATIO = 0.2 DEFAULT_MAX_SENTENCES = 3 @@ -25,13 +24,13 @@ def generate_summary( def summarize_batch( self, - article_texts: List[str], + article_texts: list[str], max_sentences: int = DEFAULT_MAX_SENTENCES, max_words: int = DEFAULT_MAX_WORDS, max_length_ratio: float = DEFAULT_SUMMARIZATION_RATIO, prompt_language="fr", model_name=None, - ) -> List[str]: + ) -> list[str]: """Basic implementation of batch summarization. Can be overridden by subclasses.""" return [ self.generate_summary( diff --git a/bertrend/services/summary/abstractive_summarizer.py b/bertrend/services/summary/abstractive_summarizer.py index 85d9c0a..3f4b576 100644 --- a/bertrend/services/summary/abstractive_summarizer.py +++ b/bertrend/services/summary/abstractive_summarizer.py @@ -4,7 +4,6 @@ # This file is part of BERTrend. import re -from typing import List from transformers import AutoTokenizer, AutoModelForSeq2SeqLM @@ -32,7 +31,7 @@ def __init__(self, model_name=DEFAULT_ABSTRACTIVE_MODEL): def generate_summary(self, article_text, **kwargs) -> str: return self.summarize_batch([article_text])[0] - def summarize_batch(self, article_texts: List[str], **kwargs) -> List[str]: + def summarize_batch(self, article_texts: list[str], **kwargs) -> list[str]: inputs = self.tokenizer( [self.WHITESPACE_HANDLER(text) for text in article_texts], return_tensors="pt", diff --git a/bertrend/services/summary/extractive_summarizer.py b/bertrend/services/summary/extractive_summarizer.py index 55e1e1d..bcfb968 100644 --- a/bertrend/services/summary/extractive_summarizer.py +++ b/bertrend/services/summary/extractive_summarizer.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. -from typing import List, Optional, Callable +from typing import Callable import nltk import numpy as np @@ -33,7 +33,7 @@ nltk.download("punkt") -def _summarize_based_on_cos_scores(cos_scores, summary_size: int) -> List[int]: +def _summarize_based_on_cos_scores(cos_scores, summary_size: int) -> list[int]: """Summarizes "something" on the basis of a cosine similarity matrix. This approach may apply to text or a set of chunks. @@ -65,7 +65,7 @@ def _summarize_based_on_cos_scores(cos_scores, summary_size: int) -> List[int]: return summary_indices -def summarize_embeddings(embeddings: Tensor, summary_size: int) -> List[int]: +def summarize_embeddings(embeddings: Tensor, summary_size: int) -> list[int]: """Summarizes "something" on the basis of its embeddings' representation. This approach may apply to text or a set of chunks. 
@@ -115,11 +115,11 @@ def generate_summary( summary = self.summarize_text(text, max_sentences, max_length_ratio) return " ".join(summary) - def get_sentences_embeddings(self, sentences: List[str]) -> List[float]: + def get_sentences_embeddings(self, sentences: list[str]) -> list[float]: """Compute the sentence embeddings""" return self.sentence_transformer_model.encode(sentences, convert_to_tensor=True) - def get_sentences(self, text: str, use_spacy: bool = False) -> List[str]: + def get_sentences(self, text: str, use_spacy: bool = False) -> list[str]: """Return a list of sentences associated to a text (use of sentence tokenizer and some basic filtering) Parameters @@ -152,7 +152,7 @@ def get_sentences(self, text: str, use_spacy: bool = False) -> List[str]: sentences = sent_tokenize(text, language="french") return sentences - def get_chunks_embeddings(self, chunks: List[str]) -> List[float]: + def get_chunks_embeddings(self, chunks: list[str]) -> list[float]: """Constructs chunks embeddings as an average of the embeddings of the sentences contained in the chunks.""" # TODO: maybe we should start by summarizing each chunk in order to have chunks of same size? chunk_embeddings = [] @@ -168,9 +168,9 @@ def get_chunks_embeddings(self, chunks: List[str]) -> List[float]: def summarize_text( self, text: str, - max_nb: Optional[int] = DEFAULT_MAX_SENTENCES, - percentage: Optional[float] = DEFAULT_SUMMARIZATION_RATIO, - ) -> List[str]: + max_nb: int | None = DEFAULT_MAX_SENTENCES, + percentage: float | None = DEFAULT_SUMMARIZATION_RATIO, + ) -> list[str]: """Summarizes a text using the maximum number of sentences given as parameter Parameters @@ -212,9 +212,9 @@ def summarize_text( def summarize_chunks( self, - chunks: List[str], - max_nb_chunks: Optional[int] = DEFAULT_CHUNKS_NUMBER_SUMMARY, - ) -> List[str]: + chunks: list[str], + max_nb_chunks: int | None = DEFAULT_CHUNKS_NUMBER_SUMMARY, + ) -> list[str]: """Summarizes a list of text chunks using their embedding representation Parameters @@ -248,10 +248,10 @@ def summarize_text_with_additional_embeddings( self, text: str, function_to_compute_embeddings: Callable, - ratio_for_additional_embeddings: Optional[float] = 0.5, - max_nb: Optional[int] = DEFAULT_MAX_SENTENCES, - percentage: Optional[float] = DEFAULT_SUMMARIZATION_RATIO, - ) -> List[str]: + ratio_for_additional_embeddings: float | None = 0.5, + max_nb: int | None = DEFAULT_MAX_SENTENCES, + percentage: float | None = DEFAULT_SUMMARIZATION_RATIO, + ) -> list[str]: """Summarizes a text using the maximum number of sentences given as parameter. 
The summary is based on the combination of two embeddings: the "standard" sentence embeddings obtained by the sentence_transformers package, and an additional embedding related to the sentences, provided by @@ -315,14 +315,14 @@ def summarize_text_with_additional_embeddings( summary = [sentences[idx].strip() for idx in summary_indices] return summary - def check_paraphrase(self, sentences: List[str]): + def check_paraphrase(self, sentences: list[str]): """Given a list of sentences, returns a list of triplets with the format [score, id1, id2] indicating the degree of paraphrase between pairs of sentences.""" paraphrases = util.paraphrase_mining(self.sentence_transformer_model, sentences) return paraphrases @staticmethod - def to_string(summary: List[str]) -> str: + def to_string(summary: list[str]) -> str: """Basic display of summary""" text = "" for s in summary: diff --git a/bertrend/tests/test_topic_model.py b/bertrend/tests/test_topic_model.py index 7e7b267..6542e6d 100644 --- a/bertrend/tests/test_topic_model.py +++ b/bertrend/tests/test_topic_model.py @@ -204,7 +204,6 @@ def test_topic_model_initialization_custom_values(): def test_initialize_models_called(topic_model): """Test that internal models are initialized properly.""" - assert hasattr(topic_model, "config_file") assert hasattr(topic_model, "config") diff --git a/bertrend/topic_analysis/prompts.py b/bertrend/topic_analysis/prompts.py index 6e49c4d..ceb2268 100644 --- a/bertrend/topic_analysis/prompts.py +++ b/bertrend/topic_analysis/prompts.py @@ -16,9 +16,8 @@ 1. Un titre concis et informatif pour ce thème (maximum 10 mots) 2. Une description détaillée du thème (environ 100 mots) - Format de réponse : - ### Titre : [Votre titre ici] - [Votre description ici] + Réponse au format JSON: + title: [Votre titre ici], description: [Votre description ici] """ TOPIC_DESCRIPTION_PROMPT_EN = """As a topic analysis expert, your task is to generate a title and a description for a specific theme. @@ -33,9 +32,8 @@ 1. A concise and informative title for this theme (maximum 10 words) 2. A detailed description of the theme (about 100 words) -Response format: -### Title: [Your title here] -[Your description here] +Response in JSON format : +title: [Your title here], description: [Your description here] """ TOPIC_DESCRIPTION_PROMPT = { diff --git a/bertrend/topic_analysis/topic_description.py b/bertrend/topic_analysis/topic_description.py index 811c731..77bdd05 100644 --- a/bertrend/topic_analysis/topic_description.py +++ b/bertrend/topic_analysis/topic_description.py @@ -2,6 +2,8 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
+import json + import pandas as pd from bertopic import BERTopic from loguru import logger @@ -16,9 +18,13 @@ def generate_topic_description( topic_number: int, filtered_docs: pd.DataFrame, language_code: str = "fr", -) -> str: - """Generates a LLM-based human-readable description of a topic""" +) -> dict: + """Generates a LLM-based human-readable description of a topic composed of a title and a description (as a dict)""" topic_words = topic_model.get_topic(topic_number) + if not topic_words: + logger.warning(f"No words found for topic number {topic_number}") + return {"title": "", "description": ""} + topic_representation = ", ".join( [word for word, _ in topic_words[:10]] ) # Get top 10 words @@ -41,12 +47,14 @@ def generate_topic_description( endpoint=LLM_CONFIG["endpoint"], model=LLM_CONFIG["model"], ) - return client.generate( + answer = client.generate( + response_format={"type": "json_object"}, user_prompt=prompt.format( topic_representation=topic_representation, docs_text=docs_text, - ) + ), ) + return json.loads(answer) except Exception as e: logger.error(f"Error calling OpenAI API: {e}") return f"Error generating description: {str(e)}" diff --git a/bertrend/trend_analysis/prompts.py b/bertrend/trend_analysis/prompts.py index f5bca2c..d9f6837 100644 --- a/bertrend/trend_analysis/prompts.py +++ b/bertrend/trend_analysis/prompts.py @@ -5,6 +5,8 @@ from pathlib import Path +from loguru import logger + from bertrend import OUTPUT_PATH # Global variables for prompts @@ -236,8 +238,8 @@ def get_prompt( return prompt -def save_html_output(model_output, output_file="signal_llm.html"): - """Function to parse the model's output and save as HTML""" +def clean_html_output(model_output) -> str: + """Function to parse the model's output""" # Clean the HTML content cleaned_html = model_output.strip() # Remove leading/trailing whitespace @@ -253,9 +255,14 @@ def save_html_output(model_output, output_file="signal_llm.html"): # Final strip to remove any remaining whitespace cleaned_html = cleaned_html.strip() + return cleaned_html + + +def save_html_output(html_output, output_file="signal_llm.html"): + """Function to save the model's output as HTML""" output_path = OUTPUT_PATH / output_file # Save the cleaned HTML with open(output_path, "w", encoding="utf-8") as file: - file.write(cleaned_html) - print(f"Cleaned HTML output saved to {output_path}") + file.write(html_output) + logger.debug(f"Cleaned HTML output saved to {output_path}") diff --git a/bertrend/trend_analysis/visualizations.py b/bertrend/trend_analysis/visualizations.py index f5a86b8..e6a4dbb 100644 --- a/bertrend/trend_analysis/visualizations.py +++ b/bertrend/trend_analysis/visualizations.py @@ -3,20 +3,12 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. -from typing import Dict, Tuple, List - -import numpy as np import pandas as pd import plotly.graph_objects as go from bertopic import BERTopic from pandas import Timestamp from plotly_resampler import FigureWidgetResampler -from bertrend.config.parameters import ( - SIGNAL_CLASSIF_LOWER_BOUND, - SIGNAL_CLASSIF_UPPER_BOUND, -) - # Visualization Settings SANKEY_NODE_PAD = 15 SANKEY_NODE_THICKNESS = 20 @@ -24,7 +16,7 @@ SANKEY_LINE_WIDTH = 0.5 -def plot_num_topics(topic_models: Dict[pd.Timestamp, BERTopic]) -> go.Figure: +def plot_num_topics(topic_models: dict[pd.Timestamp, BERTopic]) -> go.Figure: """ Plot the number of topics detected for each model. 
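Note (editor): with the prompt and topic_description.py changes above, the LLM is asked for a JSON answer and generate_topic_description() now returns a dict instead of raw markdown. A hedged usage fragment, assuming a fitted BERTopic model, a DataFrame of the topic's documents, and a configured LLM endpoint (all assumed to exist; the keyword names mirror the signature shown in this diff):

from bertrend.topic_analysis.topic_description import generate_topic_description

# topic_model is a fitted BERTopic instance and filtered_docs a DataFrame of
# documents belonging to the topic; both are assumed to exist already.
description = generate_topic_description(
    topic_model=topic_model,
    topic_number=3,
    filtered_docs=filtered_docs,
    language_code="en",
)

# Normal path: a dict parsed from the model's JSON answer,
# e.g. {"title": "...", "description": "..."}, as consumed by explore_topics.py.
print(f"### {description['title']}\n{description['description']}")
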
@@ -41,7 +33,7 @@ def plot_num_topics(topic_models: Dict[pd.Timestamp, BERTopic]) -> go.Figure: return fig_num_topics -def plot_size_outliers(topic_models: Dict[pd.Timestamp, BERTopic]) -> go.Figure: +def plot_size_outliers(topic_models: dict[pd.Timestamp, BERTopic]) -> go.Figure: """ Plot the size of the outlier topic for each model. @@ -131,7 +123,7 @@ def plot_topics_for_model(selected_model: BERTopic) -> go.Figure: return fig -def create_topic_size_evolution_figure(topic_sizes: Dict, topic_ids=None) -> go.Figure: +def create_topic_size_evolution_figure(topic_sizes: dict, topic_ids=None) -> go.Figure: fig = go.Figure() if topic_ids is None: @@ -169,57 +161,12 @@ def create_topic_size_evolution_figure(topic_sizes: Dict, topic_ids=None) -> go. return fig -def compute_popularity_values_and_thresholds( - topic_sizes, window_size: int, granularity: int, current_date -) -> Tuple[Timestamp, Timestamp, list, float, float]: - """ - Plot the evolution of topic sizes over time with colored overlays for signal regions. - - Args: - topic_sizes - window_size (int): The retrospective window size in days. - granularity (int): The granularity of the timestamps in days. - current_date (datetime): The current date selected by the user. - - Returns: - Tuple[Timestamp,Timestamp, list, float, float,]: - window_start, window_end indicates the start / end periods. - all_popularities_values - The q1 and q3 values representing the 10th and 90th percentiles of popularity values, - """ - - window_size_timedelta = pd.Timedelta(days=window_size) - granularity_timedelta = pd.Timedelta(days=granularity) - - current_date = pd.to_datetime(current_date).floor("D") # Floor to start of day - window_start = current_date - window_size_timedelta - window_end = current_date + granularity_timedelta - - # Calculate q1 and q3 values (we remove very low values of disappearing signals to not skew the thresholds) - all_popularity_values = [ - popularity - for topic, data in topic_sizes.items() - for timestamp, popularity in zip( - pd.to_datetime(data["Timestamps"]), data["Popularity"] - ) - if window_start <= timestamp <= current_date and popularity > 1e-5 - ] - - if all_popularity_values: - q1 = np.percentile(all_popularity_values, SIGNAL_CLASSIF_LOWER_BOUND) - q3 = np.percentile(all_popularity_values, SIGNAL_CLASSIF_UPPER_BOUND) - else: - q1, q3 = 0, 0 - - return window_start, window_end, all_popularity_values, q1, q3 - - def plot_topic_size_evolution( fig: go.Figure, current_date, window_start: Timestamp, window_end: Timestamp, - all_popularity_values: List[float], + all_popularity_values: list[float], q1: float, q3: float, ): diff --git a/bertrend/trend_analysis/weak_signals.py b/bertrend/trend_analysis/weak_signals.py index 8716609..41bbd30 100644 --- a/bertrend/trend_analysis/weak_signals.py +++ b/bertrend/trend_analysis/weak_signals.py @@ -3,46 +3,41 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
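plot_topic_size_evolution still takes the window bounds, the popularity values and the q1/q3 thresholds as arguments even though their computation is removed from this module. For reference, a condensed sketch of that computation (10th/90th percentiles over the retrospective window, near-zero popularities ignored; the bounds are written as literals instead of the SIGNAL_CLASSIF_* constants):

# Sketch: deriving the inputs expected by plot_topic_size_evolution, following
# the thresholding logic removed from this module.
import numpy as np
import pandas as pd

def popularity_window_and_thresholds(topic_sizes: dict, window_size: int,
                                     granularity: int, current_date):
    current_date = pd.to_datetime(current_date).floor("D")
    window_start = current_date - pd.Timedelta(days=window_size)
    window_end = current_date + pd.Timedelta(days=granularity)

    # Very low popularity values of disappearing signals are ignored so that
    # they do not skew the percentile thresholds.
    values = [
        popularity
        for data in topic_sizes.values()
        for timestamp, popularity in zip(
            pd.to_datetime(data["Timestamps"]), data["Popularity"]
        )
        if window_start <= timestamp <= current_date and popularity > 1e-5
    ]
    q1, q3 = (np.percentile(values, 10), np.percentile(values, 90)) if values else (0, 0)
    return window_start, window_end, values, q1, q3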
-import os -import pickle -from pathlib import Path -from typing import Dict, List, Tuple, Any - import numpy as np import pandas as pd import scipy from bertopic import BERTopic from loguru import logger -from tqdm import tqdm +from pandas import Timestamp from bertrend.llm_utils.openai_client import OpenAI_Client -from bertrend.config.parameters import ( - SIGNAL_CLASSIF_LOWER_BOUND, - SIGNAL_CLASSIF_UPPER_BOUND, +from bertrend import LLM_CONFIG +from bertrend.trend_analysis.prompts import ( + get_prompt, + save_html_output, + clean_html_output, ) -from bertrend import SIGNAL_EVOLUTION_DATA_DIR, LLM_CONFIG -from bertrend.trend_analysis.prompts import get_prompt, save_html_output def detect_weak_signals_zeroshot( - topic_models: Dict[pd.Timestamp, BERTopic], - zeroshot_topic_list: List[str], + topic_models: dict[Timestamp, BERTopic], + zeroshot_topic_list: list[str], granularity: int, decay_factor: float = 0.01, decay_power: float = 2, -) -> Dict[str, Dict[pd.Timestamp, Dict[str, any]]]: +) -> dict[str, dict[Timestamp, dict[str, any]]]: """ Detect weak signals based on the zero-shot list of topics to monitor. Args: - topic_models (Dict[pd.Timestamp, BERTopic]): Dictionary of BERTopic models for each timestamp. + topic_models (Dict[Timestamp, BERTopic]): Dictionary of BERTopic models for each timestamp. zeroshot_topic_list (List[str]): List of topics to monitor for weak signals. granularity (int): The granularity of the timestamps in days. decay_factor (float): The decay factor for exponential decay. decay_power (float): The decay power for exponential decay. Returns: - Dict[str, Dict[pd.Timestamp, Dict[str, any]]]: Dictionary of weak signal trends for each monitored topic. + Dict[str, Dict[Timestamp, Dict[str, any]]]: Dictionary of weak signal trends for each monitored topic. """ weak_signal_trends = {} @@ -339,177 +334,9 @@ def _apply_decay_to_inactive_topics( topic_last_popularity[topic] = decayed_popularity -def classify_signals( - topic_sizes: Dict[int, Dict[str, Any]], - window_start: pd.Timestamp, - window_end: pd.Timestamp, - q1: float, - q3: float, - rising_popularity_only: bool = True, - keep_documents: bool = True, -) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - """ - Classify signals into weak signal and strong signal dataframes. - - Args: - topic_sizes (Dict[int, Dict[str, Any]]): Dictionary storing topic sizes and related information. - window_start (pd.Timestamp): The start timestamp of the window. - window_end (pd.Timestamp): The end timestamp of the window. - q1 (float): The 10th percentile of popularity values. - q3 (float): The 50th percentile of popularity values. - rising_popularity_only (bool): Whether to consider only rising popularity topics as weak signals. - keep_documents (bool): Whether to keep track of the documents or not. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - - noise_topics_df: DataFrame containing noise topics. - - weak_signal_topics_df: DataFrame containing weak signal topics. - - strong_signal_topics_df: DataFrame containing strong signal topics. 
- """ - noise_topics = [] - weak_signal_topics = [] - strong_signal_topics = [] - - sorted_topics = sorted(topic_sizes.items(), key=lambda x: x[0]) - - for topic, data in sorted_topics: - filtered_data = _filter_data(data, window_end, keep_documents) - if not filtered_data["Timestamps"]: - continue - - window_popularities = [ - (timestamp, popularity) - for timestamp, popularity in zip( - filtered_data["Timestamps"], filtered_data["Popularity"] - ) - if window_start <= timestamp <= window_end - ] - - if window_popularities: - latest_timestamp, latest_popularity = window_popularities[-1] - docs_count = ( - filtered_data["Docs_Count"][-1] if filtered_data["Docs_Count"] else 0 - ) - paragraphs_count = ( - filtered_data["Paragraphs_Count"][-1] - if filtered_data["Paragraphs_Count"] - else 0 - ) - source_diversity = ( - filtered_data["Source_Diversity"][-1] - if filtered_data["Source_Diversity"] - else 0 - ) - - topic_data = ( - topic, - latest_popularity, - latest_timestamp, - docs_count, - paragraphs_count, - source_diversity, - filtered_data, - ) - - if latest_popularity < q1: - noise_topics.append(topic_data) - elif q1 <= latest_popularity <= q3: - if rising_popularity_only: - if _is_rising_popularity(filtered_data, latest_timestamp): - weak_signal_topics.append(topic_data) - else: - noise_topics.append(topic_data) - else: - weak_signal_topics.append(topic_data) - else: - strong_signal_topics.append(topic_data) - - return _create_dataframes( - noise_topics, weak_signal_topics, strong_signal_topics, keep_documents - ) - - -def save_signal_evolution_data( - all_merge_histories_df: pd.DataFrame, - topic_sizes: Dict[int, Dict[str, Any]], - topic_last_popularity: Dict[int, float], - topic_last_update: Dict[int, pd.Timestamp], - window_size: int, - granularity: int, - start_timestamp: pd.Timestamp, - end_timestamp: pd.Timestamp, -) -> Path: - window_size_timedelta = pd.Timedelta(days=window_size) - granularity_timedelta = pd.Timedelta(days=granularity) - - save_path = SIGNAL_EVOLUTION_DATA_DIR / f"retrospective_{window_size}_days" - os.makedirs(save_path, exist_ok=True) - - q1_values, q3_values, timestamps_over_time = [], [], [] - noise_dfs, weak_signal_dfs, strong_signal_dfs = [], [], [] - - for current_timestamp in tqdm( - pd.date_range( - start=start_timestamp, end=end_timestamp, freq=granularity_timedelta - ), - desc="Processing timestamps", - ): - window_end = current_timestamp + granularity_timedelta - window_start = window_end - granularity_timedelta - window_size_timedelta - - all_popularity_values = [ - popularity - for topic, data in topic_sizes.items() - for timestamp, popularity in zip(data["Timestamps"], data["Popularity"]) - if window_start <= timestamp <= current_timestamp and popularity > 1 ^ -5 - ] - - if all_popularity_values: - q1 = np.percentile(all_popularity_values, SIGNAL_CLASSIF_LOWER_BOUND) - q3 = np.percentile(all_popularity_values, SIGNAL_CLASSIF_UPPER_BOUND) - else: - q1, q3 = 0, 0 - - q1_values.append(q1) - q3_values.append(q3) - - noise_df, weak_signal_df, strong_signal_df = classify_signals( - topic_sizes, window_start, window_end, q1, q3, keep_documents=False - ) - - noise_dfs.append(noise_df) - weak_signal_dfs.append(weak_signal_df) - strong_signal_dfs.append(strong_signal_df) - - timestamps_over_time.append(current_timestamp) - - # Save the grouped dataframes - with open(save_path / "noise_dfs_over_time.pkl", "wb") as f: - pickle.dump(noise_dfs, f) - with open(save_path / "weak_signal_dfs_over_time.pkl", "wb") as f: - pickle.dump(weak_signal_dfs, f) - with 
open(save_path / "strong_signal_dfs_over_time.pkl", "wb") as f: - pickle.dump(strong_signal_dfs, f) - - # Save the metadata - with open(save_path / "metadata.pkl", "wb") as f: - metadata = { - "window_size": window_size, - "granularity": granularity, - "timestamps": timestamps_over_time, - "q1_values": q1_values, - "q3_values": q3_values, - } - pickle.dump(metadata, f) - - return save_path - - -def analyze_signal( - topic_number, current_date, all_merge_histories_df, granularity, language -): - topic_merge_rows = all_merge_histories_df[ - all_merge_histories_df["Topic1"] == topic_number +def analyze_signal(bertrend, topic_number: int, current_date: Timestamp): + topic_merge_rows = bertrend.all_merge_histories_df[ + bertrend.all_merge_histories_df["Topic1"] == topic_number ].sort_values("Timestamp") topic_merge_rows_filtered = topic_merge_rows[ topic_merge_rows["Timestamp"] <= current_date @@ -521,13 +348,15 @@ def analyze_signal( f"Timestamp: {row.Timestamp.strftime('%Y-%m-%d')}\n" f"Topic representation: {row.Representation1}\n" f"{' '.join(f'- {doc}' for doc in row.Documents1 if isinstance(doc, str))}\n" - f"Timestamp: {(row.Timestamp + pd.Timedelta(days=granularity)).strftime('%Y-%m-%d')}\n" + f"Timestamp: {(row.Timestamp + pd.Timedelta(days=bertrend.config['granularity'])).strftime('%Y-%m-%d')}\n" f"Topic representation: {row.Representation2}\n" f"{' '.join(f'- {doc}' for doc in row.Documents2 if isinstance(doc, str))}\n" for row in topic_merge_rows_filtered.itertuples() ] ) + language = bertrend.topic_model.config["global"]["language"] + try: openai_client = OpenAI_Client( api_key=LLM_CONFIG["api_key"], @@ -576,10 +405,7 @@ def analyze_signal( temperature=LLM_CONFIG["temperature"], max_tokens=LLM_CONFIG["max_tokens"], ) - - # Save the formatted HTML - save_html_output(formatted_html) - + formatted_html = clean_html_output(formatted_html) return summary, weak_signal_analysis, formatted_html except Exception as e: diff --git a/bertrend/utils/cache_utils.py b/bertrend/utils/cache_utils.py index b32a4e9..5097d2a 100644 --- a/bertrend/utils/cache_utils.py +++ b/bertrend/utils/cache_utils.py @@ -6,7 +6,7 @@ import hashlib import pickle from pathlib import Path -from typing import List, Any +from typing import Any import os @@ -20,7 +20,7 @@ def load_embeddings(cache_path: Path): return pickle.load(f_in) -def save_embeddings(embeddings: List, cache_path: Path): +def save_embeddings(embeddings: list, cache_path: Path): """Save embeddings as pickle""" with open(cache_path, "wb") as f_out: pickle.dump(embeddings, f_out) diff --git a/bertrend/utils/data_loading.py b/bertrend/utils/data_loading.py index 24714ef..7bbcadd 100644 --- a/bertrend/utils/data_loading.py +++ b/bertrend/utils/data_loading.py @@ -5,7 +5,7 @@ import gzip import re from pathlib import Path -from typing import Dict, Literal, List +from typing import Literal import pandas as pd from loguru import logger @@ -26,7 +26,7 @@ CITATION_COUNT_COL = "citation_count" -def find_compatible_files(path: Path, extensions: List[str]) -> List[Path]: +def find_compatible_files(path: Path, extensions: list[str]) -> list[Path]: return [f.relative_to(path) for f in path.rglob("*") if f.suffix[1:] in extensions] @@ -124,7 +124,7 @@ def split_data( def group_by_days( df: pd.DataFrame, day_granularity: int = 1 -) -> Dict[pd.Timestamp, pd.DataFrame]: +) -> dict[pd.Timestamp, pd.DataFrame]: """ Group a DataFrame by a specified number of days. 
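analyze_signal now receives the BERTrend instance, reads the granularity and language from its configuration, and returns cleaned HTML instead of writing a file; persisting the report becomes the caller's job. A usage sketch, assuming an already trained BERTrend instance with merge histories available:

# Sketch: calling the refactored analyze_signal and saving the HTML explicitly.
import pandas as pd

from bertrend.trend_analysis.weak_signals import analyze_signal
from bertrend.trend_analysis.prompts import save_html_output

def export_signal_report(bertrend, topic_number: int, current_date: pd.Timestamp):
    summary, weak_signal_analysis, formatted_html = analyze_signal(
        bertrend, topic_number, current_date
    )
    # Saving is no longer done inside analyze_signal; do it here if needed.
    save_html_output(formatted_html, output_file=f"signal_{topic_number}.html")
    return summary, weak_signal_analysis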
diff --git a/bertrend_apps/common/crontab_utils.py b/bertrend_apps/common/crontab_utils.py index 11f68bb..96eb2f1 100644 --- a/bertrend_apps/common/crontab_utils.py +++ b/bertrend_apps/common/crontab_utils.py @@ -3,32 +3,83 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. import os +import re import subprocess import sys from pathlib import Path +from cron_descriptor import ( + Options, + CasingTypeEnum, + ExpressionDescriptor, + DescriptionTypeEnum, +) from loguru import logger from bertrend import BEST_CUDA_DEVICE, BERTREND_LOG_PATH, load_toml_config -def add_job_to_crontab(schedule, command, env_vars=""): - logger.info(f"Adding to crontab: {schedule} {command}") +def get_understandable_cron_description(cron_expression: str) -> str: + """Returns a human understandable crontab description.""" + options = Options() + options.casing_type = CasingTypeEnum.Sentence + options.use_24hour_time_format = True + options.locale_code = "fr_FR" + descriptor = ExpressionDescriptor(cron_expression, options) + return descriptor.get_description(DescriptionTypeEnum.FULL) + + +def add_job_to_crontab(schedule, command, env_vars="") -> bool: + """Add the specified job to the crontab.""" + logger.debug(f"Adding to crontab: {schedule} {command}") home = os.getenv("HOME") # Create crontab, add command - NB: we use the .bashrc to source all environment variables that may be required by the command cmd = f'(crontab -l; echo "{schedule} umask 002; source {home}/.bashrc; {env_vars} {command}" ) | crontab -' returned_value = subprocess.call(cmd, shell=True) # returns the exit code in unix - logger.info(f"Crontab updated with status {returned_value}") + return returned_value == 0 -def schedule_scrapping( - feed_cfg: Path, -): +def check_cron_job(pattern: str) -> bool: + """Check if a specific pattern (expressed as a regular expression) matches crontab entries.""" + try: + # Run `crontab -l` and capture the output + result = subprocess.run( + ["crontab", "-l"], capture_output=True, text=True, check=True + ) + + # Search for the regex pattern in the crontab output + if re.search(pattern, result.stdout): + return True + else: + return False + except subprocess.CalledProcessError: + # If crontab fails (e.g., no crontab for the user), return False + return False + + +def remove_from_crontab(pattern: str) -> bool: + """Removes from the crontab the job matching the provided pattern (expressed as a regular expression)""" + if not (check_cron_job(pattern)): + logger.warning("No job matching the provided pattern") + return False + try: + # Retrieve current crontab + output = subprocess.check_output( + f"crontab -l | grep -Ev {pattern} | crontab -", shell=True + ) + return output == 0 + except subprocess.CalledProcessError: + return False + + +def schedule_scrapping(feed_cfg: Path, user: str = None): """Schedule data scrapping on the basis of a feed configuration file""" data_feed_cfg = load_toml_config(feed_cfg) schedule = data_feed_cfg["data-feed"]["update_frequency"] id = data_feed_cfg["data-feed"]["id"] - command = f"{sys.prefix}/bin/python -m bertrend_apps.data_provider scrape-feed {feed_cfg.resolve()} > {BERTREND_LOG_PATH}/cron_feed_{id}.log 2>&1" + log_path = BERTREND_LOG_PATH if not user else BERTREND_LOG_PATH / "users" / user + log_path.mkdir(parents=True, exist_ok=True) + command = f"{sys.prefix}/bin/python -m bertrend_apps.data_provider scrape-feed {feed_cfg.resolve()} > {log_path}/cron_feed_{id}.log 2>&1" add_job_to_crontab(schedule, command, "") @@ -44,3 +95,19 @@ def schedule_newsletter( 
command = f"{sys.prefix}/bin/python -m bertrend_apps.newsletters newsletters {newsletter_cfg_path.resolve()} {data_feed_cfg_path.resolve()} > {BERTREND_LOG_PATH}/cron_newsletter_{id}.log 2>&1" env_vars = f"CUDA_VISIBLE_DEVICES={cuda_devices}" add_job_to_crontab(schedule, command, env_vars) + + +def check_if_scrapping_active_for_user(feed_id: str, user: str = None) -> bool: + """Checks if a given scrapping feed is active (registered in the crontab""" + if user: + return check_cron_job(rf"scrape-feed.*/users/{user}/{feed_id}_feed.toml") + else: + return check_cron_job(rf"scrape-feed.*/{feed_id}_feed.toml") + + +def remove_scrapping_for_user(feed_id: str, user: str = None): + """Removes from the crontab the job matching the provided feed_id""" + if user: + return remove_from_crontab(rf"scrape-feed.*/users/{user}/{feed_id}_feed.toml") + else: + return remove_from_crontab(rf"scrape-feed.*/{feed_id}_feed.toml") diff --git a/bertrend_apps/common/mail_utils.py b/bertrend_apps/common/mail_utils.py index b8357ad..4edf2d7 100644 --- a/bertrend_apps/common/mail_utils.py +++ b/bertrend_apps/common/mail_utils.py @@ -9,7 +9,6 @@ from email.mime.text import MIMEText from email.utils import COMMASPACE from pathlib import Path -from typing import List # Gmail API utils from google.auth.transport.requests import Request @@ -19,11 +18,11 @@ from googleapiclient.errors import HttpError from loguru import logger -from bertrend import BASE_DATA_PATH +from bertrend import BASE_PATH SCOPES = ["https://mail.google.com/"] # full access to mail API FROM = "wattelse.ai@gmail.com" -TOKEN_PATH = BASE_DATA_PATH / "gmail_token.json" +TOKEN_PATH = BASE_PATH / "gmail_token.json" DEFAULT_GMAIL_CREDENTIALS_PATH = ( Path(__file__).parent.parent / "config" / "gmail_credentials.json" ) @@ -60,7 +59,7 @@ def get_credentials( def send_email( credentials: Credentials, subject: str, - recipients: List[str], + recipients: list[str], content: str, content_type="html", ): diff --git a/bertrend_apps/data_provider/__init__.py b/bertrend_apps/data_provider/__init__.py index ae6e745..0c6306d 100644 --- a/bertrend_apps/data_provider/__init__.py +++ b/bertrend_apps/data_provider/__init__.py @@ -2,3 +2,8 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
+# Define a pattern for a basic URL validation +URL_PATTERN = ( + r"^(https?://)?([a-z0-9-]+\.)+[a-z]{2,6}(:\d+)?(/[\w.-]*)*$|" + r"^(https?://)?(localhost|(\d{1,3}\.){3}\d{1,3})(:\d+)?(/[\w.-]*)*$" +) diff --git a/bertrend_apps/data_provider/__main__.py b/bertrend_apps/data_provider/__main__.py index 953cfc7..95c032a 100644 --- a/bertrend_apps/data_provider/__main__.py +++ b/bertrend_apps/data_provider/__main__.py @@ -15,6 +15,7 @@ from bertrend_apps.common.crontab_utils import schedule_scrapping from bertrend_apps.data_provider.arxiv_provider import ArxivProvider from bertrend_apps.data_provider.bing_news_provider import BingNewsProvider +from bertrend_apps.data_provider.curebot_provider import CurebotProvider from bertrend_apps.data_provider.google_news_provider import GoogleNewsProvider from bertrend_apps.data_provider.newscatcher_provider import NewsCatcherProvider @@ -23,6 +24,7 @@ PROVIDERS = { "arxiv": ArxivProvider, + "curebot": CurebotProvider, "google": GoogleNewsProvider, "bing": BingNewsProvider, "newscatcher": NewsCatcherProvider, @@ -46,7 +48,7 @@ def scrape( max_results: int = typer.Option( 50, help="maximum number of results per request" ), - save_path: str = typer.Option( + save_path: Path = typer.Option( None, help="Path for writing results. File is in jsonl format." ), language: str = typer.Option(None, help="Language filter"), @@ -65,7 +67,7 @@ def scrape( "to" date, formatted as YYYY-MM-DD max_results: int Maximum number of results per request - save_path: str + save_path: Path Path to the output file (jsonl format) language: str Language filter @@ -90,7 +92,7 @@ def auto_scrape( provider: str = typer.Option( "google", help="source for news [google, bing, newscatcher]" ), - save_path: str = typer.Option(None, help="Path for writing results."), + save_path: Path = typer.Option(None, help="Path for writing results."), language: str = typer.Option(None, help="Language filter"), ): """Scrape data from Arxiv, Google, Bing news or NewsCatcher (multiple requests from a configuration file: each line of the file shall be compliant with the following format: @@ -100,9 +102,11 @@ def auto_scrape( ---------- requests_file: str Text file containing the list of requests to be processed + max_results: int + Maximum number of results per request provider: str News data provider. 
Current authorized values [google, bing, newscatcher] - save_path: str + save_path: Path Path to the output file (jsonl format) language: str Language filter @@ -178,7 +182,7 @@ def _daterange(start_date, end_date, ndays): @app.command("scrape-feed") def scrape_from_feed( - feed_cfg: str = typer.Argument(help="Path of the data feed config file"), + feed_cfg: Path = typer.Argument(help="Path of the data feed config file"), ): """Scrape data from Arxiv, Google, Bing news or NewsCatcher on the basis of a feed configuration file""" data_feed_cfg = load_toml_config(feed_cfg) @@ -200,7 +204,7 @@ def scrape_from_feed( # Generate a query file with tempfile.NamedTemporaryFile() as query_file: - if provider == "arxiv": # already returns batches + if provider == "arxiv" or provider == "curebot": # already returns batches scrape( keywords=keywords, provider=provider, diff --git a/bertrend_apps/data_provider/arxiv_provider.py b/bertrend_apps/data_provider/arxiv_provider.py index 615d3c8..54b5d96 100644 --- a/bertrend_apps/data_provider/arxiv_provider.py +++ b/bertrend_apps/data_provider/arxiv_provider.py @@ -6,7 +6,6 @@ import itertools import os from datetime import datetime -from typing import List, Dict, Optional from collections import defaultdict import arxiv @@ -55,7 +54,7 @@ def get_articles( before: str, max_results: int, language: str = None, - ) -> List[Dict]: + ) -> list[dict]: """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines. Parameters @@ -101,7 +100,7 @@ def get_articles( # add citations count return self.add_citations_count(results) - def _parse_entry(self, entry: arxiv.Result) -> Optional[Dict]: + def _parse_entry(self, entry: arxiv.Result) -> dict | None: """Parses a Arxiv entry""" try: id = entry.entry_id @@ -122,7 +121,7 @@ def _parse_entry(self, entry: arxiv.Result) -> Optional[Dict]: return None @wait(1) - def _request_semantic_scholar_chunk(self, chunk: List[Dict]): + def _request_semantic_scholar_chunk(self, chunk: list[dict]): """Get information from semantic scholar API per batch of articles IDs""" ids_list = ["URL:" + entry["id"] for entry in chunk] response = requests.post( @@ -133,7 +132,7 @@ def _request_semantic_scholar_chunk(self, chunk: List[Dict]): ) return [item for item in response.json() if item is not None] - def add_citations_count(self, entries: List[Dict]): + def add_citations_count(self, entries: list[dict]): """Uses the semantic_scholar API to get the number of counts per paper""" # split list into chunks and request semantic scholar chunks = [ diff --git a/bertrend_apps/data_provider/bing_news_provider.py b/bertrend_apps/data_provider/bing_news_provider.py index c9afcb0..2e19530 100644 --- a/bertrend_apps/data_provider/bing_news_provider.py +++ b/bertrend_apps/data_provider/bing_news_provider.py @@ -4,7 +4,6 @@ # This file is part of BERTrend. 
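URL_PATTERN is a coarse check used to decide whether a query string actually carries a feed URL rather than keywords (the Curebot provider relies on it, as shown further below). A quick sketch of what it accepts (the URLs are illustrative):

# Sketch: how the coarse URL_PATTERN behaves on a few inputs.
import re

from bertrend_apps.data_provider import URL_PATTERN

for candidate in [
    "https://www.example.com/feeds/abc123",  # matches: looks like a URL
    "http://localhost:8080/feed",            # matches: localhost form
    "transition énergétique réseau",         # no match: treated as keywords
]:
    print(candidate, "->", bool(re.match(URL_PATTERN, candidate)))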
import urllib.parse -from typing import List, Dict, Optional import dateparser import feedparser @@ -36,7 +35,7 @@ def get_articles( before: str, max_results: int, language: str = None, - ) -> List[Dict]: + ) -> list[dict]: """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines""" q = self._build_query(query, after, before) logger.info(f"Querying Bing: {q}") @@ -58,7 +57,7 @@ def _clean_url(self, bing_url) -> str: # fallback (the URL does not match the expected pattern) return bing_url - def _parse_entry(self, entry: Dict) -> Optional[Dict]: + def _parse_entry(self, entry: dict) -> dict | None: """Parses a Bing news entry, uses wait decorator to force delay between 2 successive calls""" try: link = entry["link"] diff --git a/bertrend_apps/data_provider/curebot_provider.py b/bertrend_apps/data_provider/curebot_provider.py index 2d7d4f4..698229b 100644 --- a/bertrend_apps/data_provider/curebot_provider.py +++ b/bertrend_apps/data_provider/curebot_provider.py @@ -2,12 +2,13 @@ # See AUTHORS.txt # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. +import re from pathlib import Path -from typing import List, Dict, Optional import pandas as pd from loguru import logger +from bertrend_apps.data_provider import URL_PATTERN from bertrend_apps.data_provider.data_provider import DataProvider import feedparser @@ -20,6 +21,8 @@ def __init__(self, curebot_export_file: Path = None, feed_url: str = None): self.data_file = curebot_export_file if self.data_file: self.df_dict = pd.read_excel(self.data_file, sheet_name=None, dtype=str) + else: + self.df_dict = None self.feed_url = feed_url def get_articles( @@ -29,20 +32,26 @@ def get_articles( before: str = None, max_results: int = None, language: str = "fr", - ) -> List[Dict]: + ) -> list[dict]: """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines""" + if query and re.match(URL_PATTERN, query): + # if using a config file, the "query" field may contain the feed url + self.feed_url = query if self.feed_url: return self.parse_ATOM_feed() - entries = [] - for k in self.df_dict.keys(): - entries += self.df_dict[k].to_dict(orient="records") - results = [self._parse_entry(res) for res in entries] - return [ - res for res in results if res is not None - ] # sanity check to remove errors + if self.df_dict: + entries = [] + for k in self.df_dict.keys(): + entries += self.df_dict[k].to_dict(orient="records") + results = [self._parse_entry(res) for res in entries] + return [ + res for res in results if res is not None + ] # sanity check to remove errors + + return [] - def parse_ATOM_feed(self) -> List[Dict]: + def parse_ATOM_feed(self) -> list[dict]: feed = feedparser.parse(self.feed_url) # Initialize an empty list to store the entries entries = [] @@ -69,7 +78,7 @@ def parse_ATOM_feed(self) -> List[Dict]: return entries - def _parse_entry(self, entry: Dict) -> Optional[Dict]: + def _parse_entry(self, entry: dict) -> dict | None: """Parses a Curebot news entry""" try: # NB. 
we do not use the title from Gnews as it is sometimes truncated diff --git a/bertrend_apps/data_provider/data_provider.py b/bertrend_apps/data_provider/data_provider.py index f644784..9607c8a 100644 --- a/bertrend_apps/data_provider/data_provider.py +++ b/bertrend_apps/data_provider/data_provider.py @@ -6,7 +6,6 @@ import os from abc import ABC, abstractmethod from pathlib import Path -from typing import List, Dict, Tuple, Optional import jsonlines import langdetect @@ -55,7 +54,7 @@ def get_articles( before: str, max_results: int, language: str = None, - ) -> List[Dict]: + ) -> list[dict]: """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines. Parameters @@ -79,8 +78,8 @@ def get_articles( pass def get_articles_batch( - self, queries_batch: List[List], max_results: int, language: str = None - ) -> List[Dict]: + self, queries_batch: list[list], max_results: int, language: str = None + ) -> list[dict]: """Requests the news data provider for a list of queries, collects a set of URLs to be parsed, return results as json lines""" articles = [] @@ -103,12 +102,13 @@ def parse_article(self, url: str) -> Article: article = self.article_parser.extract(url=url) return article - def store_articles(self, data: List[Dict], file_path: Path): + def store_articles(self, data: list[dict], file_path: Path): """Store articles to a specific path as json lines""" if not data: logger.error("No data to be stored!") return -1 - with jsonlines.open(file_path, "w") as writer: + with jsonlines.open(file_path, "a") as writer: + # append to existing file writer.write_all(data) logger.info(f"Data stored to {file_path} [{len(data)} entries].") @@ -121,7 +121,7 @@ def load_articles(self, file_path: Path) -> pd.DataFrame: return pd.DataFrame(data) @wait_if_seen_url(0.2) - def _get_text(self, url: str) -> Tuple[str, str]: + def _get_text(self, url: str) -> tuple[str, str]: """Extracts text and (clean) title from an article URL""" if any(ele in url for ele in BLACKLISTED_URL): logger.warning(f"Source of {url} is blacklisted!") @@ -136,7 +136,7 @@ def _get_text(self, url: str) -> Tuple[str, str]: logger.warning("Parsing of text failed with Goose3, trying newspaper3k") return self._get_text_alternate(url) - def _get_text_alternate(self, url: str) -> Tuple[str, str]: + def _get_text_alternate(self, url: str) -> tuple[str, str]: """Extracts text from an article URL""" logger.debug(f"Extracting text from {url} with newspaper3k") article = Article(url) @@ -156,11 +156,11 @@ def _filter_out_bad_text(self, text): return text @abstractmethod - def _parse_entry(self, entry: Dict) -> Optional[Dict]: + def _parse_entry(self, entry: dict) -> dict | None: """Parses a NewsCatcher news entry""" pass - def process_entries(self, entries: List, lang_filter: str = None): + def process_entries(self, entries: list, lang_filter: str = None): # Number of parallel jobs you want to run (adjust as needed) num_jobs = -1 # all available cpus diff --git a/bertrend_apps/data_provider/google_news_provider.py b/bertrend_apps/data_provider/google_news_provider.py index 45d5443..d00c9d6 100644 --- a/bertrend_apps/data_provider/google_news_provider.py +++ b/bertrend_apps/data_provider/google_news_provider.py @@ -4,7 +4,6 @@ # This file is part of BERTrend. 
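store_articles now appends to the JSONL file, so repeated scraping runs accumulate entries in the same file and deduplication becomes the reader's concern. A hedged sketch of reading back and deduplicating (the "url" field name is an assumption about the stored entries; adjust it to the actual schema):

# Sketch: reading back an appended JSONL file and dropping duplicate articles.
from pathlib import Path

import jsonlines
import pandas as pd

def load_articles_deduplicated(file_path: Path) -> pd.DataFrame:
    with jsonlines.open(file_path) as reader:
        df = pd.DataFrame(list(reader))
    if "url" in df.columns:
        df = df.drop_duplicates(subset=["url"]).reset_index(drop=True)
    return df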
import urllib.parse -from typing import List, Dict, Optional import dateparser from loguru import logger @@ -39,7 +38,7 @@ def get_articles( before: str = None, max_results: int = 50, language: str = "fr", - ) -> List[Dict]: + ) -> list[dict]: """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines""" # FIXME: this may be blocked by google if language and language != "en": @@ -65,7 +64,7 @@ def _build_query(self, keywords: str, after: str = None, before: str = None) -> return query - def _parse_entry(self, entry: Dict) -> Optional[Dict]: + def _parse_entry(self, entry: dict) -> dict | None: """Parses a Google News entry""" try: # NB. we do not use the title from Gnews as it is sometimes truncated diff --git a/bertrend_apps/data_provider/newscatcher_provider.py b/bertrend_apps/data_provider/newscatcher_provider.py index 6e3b9ea..742dc8c 100644 --- a/bertrend_apps/data_provider/newscatcher_provider.py +++ b/bertrend_apps/data_provider/newscatcher_provider.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. import os -from typing import List, Dict, Optional import dateparser from loguru import logger @@ -35,7 +34,7 @@ def get_articles( before: str, max_results: int, language: str = None, - ) -> List[Dict]: + ) -> list[dict]: """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines""" # Use the API to search articles @@ -47,7 +46,7 @@ def get_articles( entries = result["articles"][:max_results] return self.process_entries(entries, language) - def _parse_entry(self, entry: Dict) -> Optional[Dict]: + def _parse_entry(self, entry: dict) -> dict | None: """Parses a NewsCatcher news entry""" try: link = entry["link"] diff --git a/bertrend_apps/exploration/curebot/app.py b/bertrend_apps/exploration/curebot/app.py new file mode 100644 index 0000000..f7edb7a --- /dev/null +++ b/bertrend_apps/exploration/curebot/app.py @@ -0,0 +1,45 @@ +import streamlit as st + +from bertrend_apps.exploration.curebot.tabs import tab1, tab2, tab3 + +# Set wide layout +st.set_page_config( + page_title="Curebot - Exploration de sujets", + layout="wide", +) + +# Set app title +st.title("Curebot - Exploration de sujets") + +# Set sidebar +with st.sidebar: + with st.expander("Paramètres"): + st.checkbox( + "Utiliser les tags", + key="use_tags", + value=False, + help="Utiliser les tags Curebot pour orienter la recherche de sujets.", + ) + st.number_input( + "Nombre d'articles minimum par sujet", + 2, + 50, + 5, + key="min_articles_per_topic", + help="Permet d'influencer le nombre total de sujets trouvés par le modèle. 
Plus ce nombre est élevé, moins il y aura de sujets.", + ) + +# Create tabs +tab1_content, tab2_content, tab3_content = st.tabs( + ["Données", "Résultats", "Newsletter"] +) + +# Main tabs +with tab1_content: + tab1.show() + +with tab2_content: + tab2.show() + +with tab3_content: + tab3.show() diff --git a/bertrend_apps/exploration/curebot/app_utils.py b/bertrend_apps/exploration/curebot/app_utils.py new file mode 100644 index 0000000..7a3cbc2 --- /dev/null +++ b/bertrend_apps/exploration/curebot/app_utils.py @@ -0,0 +1,313 @@ +import tomllib +import json +from bertopic import BERTopic +import numpy as np +import pandas as pd +from pathlib import Path +import torch + +import streamlit as st +from streamlit.runtime.uploaded_file_manager import UploadedFile +from sentence_transformers import SentenceTransformer +import plotly.graph_objects as go +from urllib.parse import urlparse + +from bertrend.BERTopicModel import BERTopicModel +from bertrend.llm_utils.openai_client import OpenAI_Client +from bertrend_apps.exploration.curebot.prompts import ( + TOPIC_DESCRIPTION_SYSTEM_PROMPT, + TOPIC_SUMMARY_SYSTEM_PROMPT, +) + +# Get configuration file +CONFIG = tomllib.load(open(Path(__file__).parent / "config.toml", "rb")) + +# Set curebot column name +URL_COLUMN = CONFIG["data"]["url_column"] +TEXT_COLUMN = CONFIG["data"]["text_column"] +TITLE_COLUMN = CONFIG["data"]["title_column"] +SOURCE_COLUMN = CONFIG["data"]["source_column"] +TIMESTAMP_COLUMN = CONFIG["data"]["timestamp_column"] +TAGS_COLUMN = CONFIG["data"]["tags_column"] + +# Topics config +TOP_N_WORDS = CONFIG["topics"]["top_n_words"] + +# Newsletter +NEWSLETTER_TEMPLATE = CONFIG["newsletter"]["template"] + +# Load embdding model +EMBEDDING_MODEL = SentenceTransformer( + CONFIG["embedding"]["model_name"], + model_kwargs={"torch_dtype": torch.float16}, + trust_remote_code=True, +) + + +@st.cache_data(show_spinner=False) +def concat_data_from_files(files: list[UploadedFile]) -> pd.DataFrame: + """ + Concatenate data from multiple Excel files into a single DataFrame. + """ + df_list = [] + for file in files: + df_list.append(pd.read_excel(file)) + df = pd.concat(df_list, ignore_index=True) + + return df + + +@st.cache_data(show_spinner=False) +def chunk_df( + df: pd.DataFrame, chunk_size: int = 100, overlap: int = 20 +) -> pd.DataFrame: + """ + Split df texts into overlapping chunks while preserving other columns. + + Parameters: + ----------- + df : pandas.DataFrame + Input DataFrame containing text and other columns + chunk_size : int + Number of words in each chunk + overlap : int + Number of words to overlap between chunks + + Returns: + -------- + pandas.DataFrame + DataFrame with chunked texts, preserving other column values + """ + + # # Apply chunking to each row + # chunked_data = [] + # for _, row in df.iterrows(): + # chunked_data.extend(split_text_to_chunks(row, chunk_size, overlap)) + + # # Create new DataFrame from chunked data + # df_split = ( + # pd.DataFrame(chunked_data) + # .sort_values(by=TIMESTAMP_COLUMN) + # .reset_index(drop=True) + # ) + + return df.copy() + + +@st.cache_data(show_spinner=False) +def get_embeddings(texts: list[str]) -> np.ndarray: + """Get embeddings for a list of texts.""" + return EMBEDDING_MODEL.encode(texts) + + +@st.cache_data(show_spinner=False) +def fit_bertopic( + docs: list[str], + embeddings: np.ndarray, + min_articles_per_topic: int, + zeroshot_topic_list: list[str] | None = None, +) -> tuple[BERTopic, list[int]]: + """ + Fit BERTopic model on a list of documents and their embeddings. 
+ """ + # Override default parameters + bertopic_config = {"hdbscan_model": {"min_cluster_size": min_articles_per_topic}} + # Initialize topic model + topic_model = BERTopicModel(config=bertopic_config) + + # Train topic model + topic_model_output = topic_model.fit( + docs=docs, + embedding_model=EMBEDDING_MODEL, + embeddings=embeddings, + zeroshot_topic_list=zeroshot_topic_list, + zeroshot_min_similarity=0.65, + ) + + return topic_model_output.topic_model, topic_model_output.topics + + +@st.cache_data(show_spinner=False) +def get_improved_topic_description( + df: pd.DataFrame, _topics_info: pd.DataFrame +) -> list[str]: + """Get improved topic description using LLM.""" + # Get llm client + llm_client = OpenAI_Client() + + # List of improved topics description + improved_descriptions = [] + + # Loop over topics + for topic_number in range(len(_topics_info)): + topic_df = df[df["topics"] == topic_number] + user_prompt = "\n\n".join( + topic_df.apply( + lambda row: f"Titre : {row[TITLE_COLUMN]}\nArticle : {row[TEXT_COLUMN][0:2000]}...", + axis=1, + ) + ) + response = llm_client.generate( + user_prompt=user_prompt, + system_prompt=TOPIC_DESCRIPTION_SYSTEM_PROMPT, + response_format={"type": "json_object"}, + ) + improved_descriptions.append(json.loads(response)["titre"]) + + return improved_descriptions + + +@st.cache_data(show_spinner=False) +def create_newsletter( + df: pd.DataFrame, + topics_info: pd.DataFrame, + nb_topics: int, + nb_articles_per_topic: int, +): + """ + Create newsletter dict from containing the newsletter content in the following format: + { + "title": "Newsletter", + "min_timestamp": "Monday 01 January 2025", + "max_timestamp": "Sunday 07 January 2025", + "topics": [ + { + "title": "Topic 1", + "keywords": "#keyword1 #keyword2 #keyword3", + "summary": "Summary of topic 1" + "articles": [ + { + "title": "Article 1", + "url": "https://www.article1.com", + "timestamp": "Monday 01 January 2023", + "source": "Source 1" + }, + ... + ] + }, + ... 
+ ] + } + """ + # Create newsletter dict that stores the newsletter content + newsletter_dict = {"title": "Newsletter"} + + # Get min and max date of the articles in the dataframe + newsletter_dict["min_timestamp"] = ( + df[TIMESTAMP_COLUMN].min().strftime("%A %d %B %Y") + ) + newsletter_dict["max_timestamp"] = ( + df[TIMESTAMP_COLUMN].max().strftime("%A %d %B %Y") + ) + + newsletter_dict["topics"] = [] + for i in range(nb_topics): + # Dict to store topic info + topic_dict = {} + + # Get title and key words + topic_dict["title"] = topics_info.iloc[i]["llm_description"] + topic_dict["keywords"] = ( + "#" + " #".join(topics_info.iloc[i]["Representation"]).strip() + ) + + # Filter df to get articles for the topic + topic_df = df[df["topics"] == i] + + # Get first `newsletter_nb_articles_per_topic` articles for the topic + topic_df = topic_df.head(min(nb_articles_per_topic, len(topic_df))) + + # Get a summary of the topic + user_prompt = "\n\n".join( + topic_df.apply( + lambda row: f"Titre : {row[TITLE_COLUMN]}\nArticle : {row[TEXT_COLUMN][0:2000]}...", + axis=1, + ) + ) + llm_client = OpenAI_Client() + response = llm_client.generate( + user_prompt=user_prompt, + system_prompt=TOPIC_SUMMARY_SYSTEM_PROMPT, + response_format={"type": "json_object"}, + ) + topic_dict["summary"] = json.loads(response)["résumé"] + topic_dict["articles"] = [] + for _, row in topic_df.iterrows(): + article_dict = {} + article_dict["title"] = row[TITLE_COLUMN] + article_dict["url"] = row[URL_COLUMN] + article_dict["timestamp"] = row[TIMESTAMP_COLUMN].strftime("%A %d %B %Y") + article_dict["source"] = row[SOURCE_COLUMN] + topic_dict["articles"].append(article_dict) + newsletter_dict["topics"].append(topic_dict) + return newsletter_dict + + +def display_source_distribution( + representative_df: pd.DataFrame, selected_sources: list[str] +): + """Display the distribution of sources in a pie chart.""" + + source_counts = representative_df[URL_COLUMN].apply(get_website_name).value_counts() + + # Create a list to store the 'pull' values for each slice + pull = [] + + # Determine which slices should be pulled out + for source in source_counts.index: + if source in selected_sources and "All" not in selected_sources: + pull.append(0.2) + else: + pull.append(0) + + fig = go.Figure( + data=[ + go.Pie( + labels=source_counts.index, + values=source_counts.values, + pull=pull, + textposition="inside", + textinfo="percent+label", + hole=0.3, + ) + ] + ) + + fig.update_layout( + showlegend=False, height=600, width=500, margin=dict(t=0, b=0, l=0, r=0) + ) + + st.plotly_chart(fig, use_container_width=True) + + +def get_website_name(url): + """Extract website name from URL, handling None, NaN, and invalid URLs.""" + if pd.isna(url) or url is None or isinstance(url, float): + return "Unknown Source" + try: + return ( + urlparse(str(url)).netloc.replace("www.", "").split(".")[0] + or "Unknown Source" + ) + except: + return "Unknown Source" + + +def display_representative_documents(filtered_df: pd.DataFrame): + """Display representative documents for the selected topic.""" + with st.container(border=False, height=600): + for _, doc in filtered_df.iterrows(): + website_name = get_website_name(doc[URL_COLUMN]) + date = doc[TIMESTAMP_COLUMN].strftime("%A %d %b %Y %H:%M:%S") + snippet = ( + doc[TEXT_COLUMN][:200] + "..." 
+ if len(doc[TEXT_COLUMN]) > 200 + else doc[TEXT_COLUMN] + ) + + content = f"""**{doc[TITLE_COLUMN]}**\n\n{date} | {'Unknown Source' if website_name == 'Unknown Source' else website_name}\n\n{snippet}""" + + if website_name != "Unknown Source": + st.link_button(content, doc[URL_COLUMN]) + else: + st.markdown(content) diff --git a/bertrend_apps/exploration/curebot/config.toml b/bertrend_apps/exploration/curebot/config.toml new file mode 100644 index 0000000..3cd91b7 --- /dev/null +++ b/bertrend_apps/exploration/curebot/config.toml @@ -0,0 +1,33 @@ +[embedding] +model_name = "dangvantuan/french-document-embedding" + +[data] +title_column = "Titre de la ressource" +text_column = "Contenu de la ressource" +url_column = "URL de la ressource" +source_column = "Domaine de la ressource" +timestamp_column = "Date de trouvaille" +tags_column = "Tags" + +[topics] +top_n_words = 5 + +[newsletter] +template = """ +# {{ title }} + +*Du {{ min_timestamp }} au {{ max_timestamp }}.* + +{% for topic in topics %} +## {{ topic.title }} + +**{{ topic.keywords }}** + +{{ topic.summary }} + +{% for article in topic.articles %} +* [{{ article.title }}]({{ article.url }}) - **{{ article.source }}** - *{{ article.timestamp }}* + +{% endfor %} +{% endfor %} +""" \ No newline at end of file diff --git a/bertrend_apps/exploration/curebot/veille_analyse.py b/bertrend_apps/exploration/curebot/old_veille_analyse.py similarity index 92% rename from bertrend_apps/exploration/curebot/veille_analyse.py rename to bertrend_apps/exploration/curebot/old_veille_analyse.py index 1544de5..a0bd21a 100644 --- a/bertrend_apps/exploration/curebot/veille_analyse.py +++ b/bertrend_apps/exploration/curebot/old_veille_analyse.py @@ -4,7 +4,6 @@ # This file is part of BERTrend. from pathlib import Path from tempfile import TemporaryDirectory -from typing import List import inspect import pandas as pd @@ -43,7 +42,7 @@ @st.cache_data -def parse_data_from_files(files: List[UploadedFile]) -> pd.DataFrame: +def parse_data_from_files(files: list[UploadedFile]) -> pd.DataFrame: """Read a list of Excel files and return a single dataframe containing the data""" dataframes = [] @@ -126,23 +125,6 @@ def train_model(): ) -def create_newsletter(): - with st.spinner("Création de la newsletters..."): - st.session_state["newsletters"], _, _ = generate_newsletter( - topic_model=st.session_state["topic_model"], - df=st.session_state["df"], - df_split=st.session_state["df_split"], - topics=st.session_state["topics"], - top_n_topics=st.session_state["newsletter_nb_topics"], - top_n_docs=st.session_state["newsletter_nb_docs"], - improve_topic_description=True, - summarizer_class=GPTSummarizer, - summary_mode="topic", - openai_model_name=st.session_state["openai_model_name"], - nb_sentences=st.session_state["nb_sentences"], - ) - - @st.experimental_dialog("Newsletter preview", width="large") def preview_newsletter(): content = md2html(st.session_state["final_newsletter"], css_style=css_style) diff --git a/bertrend_apps/exploration/curebot/prompts.py b/bertrend_apps/exploration/curebot/prompts.py new file mode 100644 index 0000000..bfc0f25 --- /dev/null +++ b/bertrend_apps/exploration/curebot/prompts.py @@ -0,0 +1,24 @@ +TOPIC_DESCRIPTION_SYSTEM_PROMPT = """ +Vous êtes expert en veille d'actualité et en anaylse thématique. +Dans le contexte d'une analyse d'articles de presse, plusieurs articles ont été regroupés en un même thème. +Votre tâche est de générer un titre pour ce thème sur la base des articles qui appartiennent à ce thème. 
+A partir de la liste des articles fournie (titre et contenu), rédigez un titre pour le thème. +Le titre doit être concis (maximum 5 mots) et représenter au mieux la spécificité du thème. +Répondez sous la forme d'un JSON suivant le format ci-dessous : +{ + "titre": "" +} +""" + +TOPIC_SUMMARY_SYSTEM_PROMPT = """ +Vous êtes expert en veille d'actualité et en anaylse thématique. +Dans le contexte d'une analyse d'articles de presse, plusieurs articles ont été regroupés en un même thème. +Votre tâche est de générer un résumé pour ce thème sur la base des articles qui appartiennent à ce thème. +A partir de la liste des articles fournie (titre et contenu), rédigez un résumé pour le thème. +Le résumé doit être concis (maximum 100 mots) et représenter au mieux la spécificité du thème. +Il doit pas commencer par "Les articles parlent de..." ou équivalent et doit être écrit dans un style journalistique. +Répondez sous la forme d'un JSON suivant le format ci-dessous : +{ + "résumé": "" +} +""" diff --git a/bertrend_apps/exploration/curebot/tabs/tab1.py b/bertrend_apps/exploration/curebot/tabs/tab1.py new file mode 100644 index 0000000..4862122 --- /dev/null +++ b/bertrend_apps/exploration/curebot/tabs/tab1.py @@ -0,0 +1,148 @@ +import pandas as pd +import streamlit as st + +from bertrend_apps.exploration.curebot.app_utils import ( + TAGS_COLUMN, + TEXT_COLUMN, + TIMESTAMP_COLUMN, + TITLE_COLUMN, + TOP_N_WORDS, + concat_data_from_files, + fit_bertopic, + get_embeddings, + get_improved_topic_description, +) + + +def show() -> None: + # Data uploading component + upload_data() + + # Load data into dataframe + if st.session_state.get("uploaded_files"): + try: + preprocess_data() + except Exception as e: + st.error( + f"Erreur lors du chargement des données. Vérifiez que vos données respectent le format Curebot attendu." + ) + + # If data is loaded + if "df" in st.session_state: + # Show data + with st.expander("Voir les données", expanded=False): + st.write(st.session_state["df"]) + + # Button to train model + if st.button("Détecter les sujets", type="primary"): + with st.spinner("Détection des sujets..."): + # Train model + train_model() + + # If topic model is trained, update df with topics and llm description + if "topic_model" in st.session_state: + st.session_state["df"]["topics"] = st.session_state["topics"] + with st.spinner("Génération des titres des sujets..."): + st.session_state["topics_info"]["llm_description"] = ( + get_improved_topic_description( + st.session_state["df"], st.session_state["topics_info"] + ) + ) + st.success("Sujets détectés, voir l'onglet résultats.") + + +def upload_data() -> None: + """ + Data uploading component for Curebot format. + Sets in sessions_state: + - "uploaded_files": list of uploaded files + """ + # Excel files input + st.session_state["uploaded_files"] = st.file_uploader( + "Fichiers Excel au format rapport Curebot `.xlsx`", + accept_multiple_files=True, + help="Glisser/déposer dans cette zone les exports Curebot au format Excel", + ) + + +def preprocess_data() -> None: + """ + Preprocess data from uploaded files. 
+ Sets in session_state: + - "df": dataframe with all data + """ + # Concatenate uploaded Excel files into a single dataframe + df = concat_data_from_files(st.session_state["uploaded_files"]) + + # Remove duplicates based on title and text columns + df = df.drop_duplicates(subset=[TITLE_COLUMN, TEXT_COLUMN]).reset_index(drop=True) + + # Remove rows where text is empty + df = df[df[TEXT_COLUMN].notna()].reset_index(drop=True) + + # Sort df by date + df[TIMESTAMP_COLUMN] = pd.to_datetime(df[TIMESTAMP_COLUMN]) + df = df.sort_values(by=TIMESTAMP_COLUMN, ascending=False).reset_index(drop=True) + + st.session_state["df"] = df + + +def train_model() -> None: + """ + Train a BERTopic model based on provided data. + Sets in session_state: + - "embeddings": embeddings of the dataset + - "topic_model": trained BERTopic model + - "topics": topics extracted by the model + - "topics_info": information about the topics + """ + # Get texts list and embeddings + texts_list = st.session_state["df"][TEXT_COLUMN].tolist() + embeddings = get_embeddings(texts_list) + + # If use_tags is True, get tags from dataframe + if st.session_state["use_tags"]: + # Convert tags to string + st.session_state["df"][TAGS_COLUMN] = st.session_state["df"][ + TAGS_COLUMN + ].astype(str) + + # Get zeroshot_topic_list from tags + zeroshot_topic_list = ( + st.session_state["df"][TAGS_COLUMN] + .fillna("") + .str.findall(r"#\w+") + .explode() + .unique() + ) + + # Remove # and _ from tags and convert to string + zeroshot_topic_list = [ + str(tag).replace("#", "").replace("_", " ") + for tag in zeroshot_topic_list + if tag + ] + # Else, set zeroshot_topic_list to None + else: + zeroshot_topic_list = None + + # Train topic model + bertopic, topics = fit_bertopic( + texts_list, + embeddings, + st.session_state["min_articles_per_topic"], + zeroshot_topic_list=zeroshot_topic_list, + ) + + # Set session_state + st.session_state["topic_model"] = bertopic + st.session_state["topics"] = topics + + topic_info = bertopic.get_topic_info() + topic_info["Representation"] = topic_info["Representation"].apply( + lambda x: x[:TOP_N_WORDS] + ) + + st.session_state["topics_info"] = topic_info[ + topic_info["Topic"] != -1 + ] # exclude -1 topic from topic list diff --git a/bertrend_apps/exploration/curebot/tabs/tab2.py b/bertrend_apps/exploration/curebot/tabs/tab2.py new file mode 100644 index 0000000..8145d2b --- /dev/null +++ b/bertrend_apps/exploration/curebot/tabs/tab2.py @@ -0,0 +1,95 @@ +import streamlit as st + +from bertrend_apps.exploration.curebot.app_utils import ( + URL_COLUMN, + display_representative_documents, + display_source_distribution, + get_website_name, +) + + +def show() -> None: + # Check if a model is trained + if "topic_model" not in st.session_state: + st.warning("Veuillez ajouter des données et lancer la détection des sujets.") + else: + # Show sidebar with topic list + show_topic_list() + + # Display selected topic info + display_topic_info(st.session_state.get("selected_topic_number", 0)) + + +def show_topic_list(): + """ + Show topic list in the sidebar. + Each topic is a button that, when clicked, sets session_state: + - "selected_topic_number": topic number, this topic is displayed in results tab. 
+ """ + with st.sidebar: + with st.expander("Sujets détectés", expanded=True): + # Topics list + for _, topic in st.session_state["topics_info"].iterrows(): + # Display button for each topic + topic_number = topic["Topic"] + button_title = topic["llm_description"] + st.button( + str(topic_number + 1) + " - " + button_title, + use_container_width=True, + on_click=set_topic_selection, + args=(topic_number,), + ) + + +def set_topic_selection(topic_number: int): + """Set "selected_topic_number" in the session state.""" + st.session_state["selected_topic_number"] = topic_number + + +def display_topic_info(topic_number: int): + """Display "selected_topic_number" associated topic information.""" + # Filter df to get only the selected topic + selected_topic_df = st.session_state["df"][ + st.session_state["df"]["topics"] == topic_number + ] + + # Get topic info + docs_count = len(selected_topic_df) + key_words = st.session_state["topics_info"].iloc[topic_number]["Representation"] + llm_description = st.session_state["topics_info"].iloc[topic_number][ + "llm_description" + ] + + # Display topic info + st.write(f"# {llm_description}") + st.write(f"### {docs_count} documents") + st.markdown(f"### #{' #'.join(key_words)}") + + # Get unique sources + sources = selected_topic_df[URL_COLUMN].apply(get_website_name).unique() + + # Multi-select for sources + selected_sources = st.multiselect( + "Sélectionner les sources à afficher :", + options=["All"] + sorted(list(sources)), + default=["All"], + ) + + # Create two columns + col21, col22 = st.columns([0.3, 0.7]) + + with col21: + # Pass the full representative_df to display_source_distribution + display_source_distribution(selected_topic_df, selected_sources) + + with col22: + # Filter the dataframe only for document display + if "All" not in selected_sources: + filtered_df = selected_topic_df[ + selected_topic_df[URL_COLUMN] + .apply(get_website_name) + .isin(selected_sources) + ] + else: + filtered_df = selected_topic_df + display_representative_documents(filtered_df) diff --git a/bertrend_apps/exploration/curebot/tabs/tab3.py b/bertrend_apps/exploration/curebot/tabs/tab3.py new file mode 100644 index 0000000..971ea24 --- /dev/null +++ b/bertrend_apps/exploration/curebot/tabs/tab3.py @@ -0,0 +1,125 @@ +import locale +from pathlib import Path +import jinja2 +import streamlit as st +from bertrend.llm_utils.newsletter_features import md2html +from bertrend_apps.exploration.curebot.app_utils import ( + NEWSLETTER_TEMPLATE, + create_newsletter, +) + +# Set french locale +locale.setlocale(locale.LC_ALL, "fr_FR.UTF-8") + +CSS_STYLE = ( + Path(__file__).parent.parent.parent.parent.parent + / "bertrend/llm_utils/newsletter.css" +) + + +def show() -> None: + if "topic_model" not in st.session_state: + st.warning("Veuillez ajouter des données et lancer la détection des sujets.") + else: + # Show newsletter parameters + newsletter_parameters() + + # Columns for buttons + col1, col2, col3 = st.columns([1, 1, 1]) + + # Show newsletter creation button + with col1: + if st.button("Créer la newsletter"): + with st.spinner("Création de la newsletter..."): + get_newsletter() + + # Show newsletter edition button + with col2: + st.button( + "Éditer", + on_click=edit_newsletter, + disabled=not "newsletter_text" in st.session_state, + ) + + # Show newsletter download button + with col3: + # If newsletter created, set it as data to download + if "newsletter_text" in st.session_state: + data = md2html(st.session_state["newsletter_text"], css_style=CSS_STYLE) + # Else, set data 
to empty string + else: + data = "" + + # Download button + st.download_button( + "Télécharger", + file_name="newsletters.html", + mime="text/html", + data=data, + disabled=not "newsletter_text" in st.session_state, + ) + + # Show newsletter + if "newsletter_text" in st.session_state: + st.write(st.session_state["newsletter_text"], unsafe_allow_html=True) + + +def newsletter_parameters() -> None: + """ + Show newsletter parameters: + - Number of topics to include in the newsletter + - Number of articles per topic to include in the newsletter + """ + col31, col32 = st.columns([1, 1]) + with col31: + # Select number of topics to include in the newsletter + st.slider( + "Nombre de sujets", + 1, + len(st.session_state["topics_info"]), + value=min(len(st.session_state["topics"]), 3), + key="newsletter_nb_topics", + ) + with col32: + # Select number of articles per topic to include in the newsletter + st.slider( + "Nombre d'articles par sujet", + 1, + 10, + value=4, + key="newsletter_nb_articles_per_topic", + ) + + +def get_newsletter() -> None: + """ + Create a newsletter based on the selected topics and articles. + Sets in sessions_state: + - newsletter_dict: a dictionary containing the newsletter data + - newsletter_text: the newsletter text + """ + # Newsletter dict + st.session_state["newsletter_dict"] = create_newsletter( + st.session_state["df"], + st.session_state["topics_info"], + st.session_state["newsletter_nb_topics"], + st.session_state["newsletter_nb_articles_per_topic"], + ) + + # Newsletter text + template = jinja2.Template(NEWSLETTER_TEMPLATE) + st.session_state["newsletter_text"] = template.render( + st.session_state["newsletter_dict"] + ) + + +@st.dialog("Éditer la newsletter", width="large") +def edit_newsletter() -> None: + edited_newsltter = st.text_area( + "", + value=st.session_state["newsletter_text"], + height=500, + ) + st.session_state["newsletter_text"] = edited_newsltter + if st.button("Enregistrer", type="primary"): + st.rerun() diff --git a/bertrend_apps/newsletters/__main__.py b/bertrend_apps/newsletters/__main__.py index d3a60b5..bba8511 100644 --- a/bertrend_apps/newsletters/__main__.py +++ b/bertrend_apps/newsletters/__main__.py @@ -6,7 +6,6 @@ import glob import os from pydoc import locate -from typing import List, Tuple import pandas as pd import typer @@ -205,8 +204,8 @@ def _train_topic_model( dataset: pd.DataFrame, embedding_model: str, embeddings: ndarray, - ) -> Tuple[List, BERTopic]: - topic_model = BERTopicModel.from_config(config_file) + ) -> tuple[list, BERTopic]: + topic_model = BERTopicModel(config_file) output = topic_model.fit( docs=dataset[TEXT_COLUMN], embedding_model=embedding_model, diff --git a/bertrend_apps/prospective_demo/.streamlit/secrets.toml b/bertrend_apps/prospective_demo/.streamlit/secrets.toml new file mode 100644 index 0000000..af460ea --- /dev/null +++ b/bertrend_apps/prospective_demo/.streamlit/secrets.toml @@ -0,0 +1,8 @@ +# .streamlit/secrets.toml + +[passwords] +# Follow the rule: username = "password" +jerome = "jerome" +guillaume = "guillaume" +dsia = "dsia" +nemo = "nemo" diff --git a/bertrend_apps/prospective_demo/__init__.py b/bertrend_apps/prospective_demo/__init__.py new file mode 100644 index 0000000..797ff65 --- /dev/null +++ b/bertrend_apps/prospective_demo/__init__.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
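The newsletter tab renders NEWSLETTER_TEMPLATE, the Jinja2 template declared in config.toml, against the dict produced by create_newsletter. A minimal rendering sketch with a hand-built dict of the same shape (all values are illustrative; note that importing app_utils also instantiates the embedding model declared at module scope):

# Sketch: rendering the newsletter template against a dict that follows the
# structure documented in create_newsletter. All values below are fake.
import jinja2

from bertrend_apps.exploration.curebot.app_utils import NEWSLETTER_TEMPLATE

newsletter_dict = {
    "title": "Newsletter",
    "min_timestamp": "Monday 01 January 2025",
    "max_timestamp": "Sunday 07 January 2025",
    "topics": [
        {
            "title": "Topic 1",
            "keywords": "#keyword1 #keyword2",
            "summary": "Summary of topic 1.",
            "articles": [
                {
                    "title": "Article 1",
                    "url": "https://www.example.com/article-1",
                    "timestamp": "Monday 01 January 2025",
                    "source": "example.com",
                }
            ],
        }
    ],
}

markdown_text = jinja2.Template(NEWSLETTER_TEMPLATE).render(newsletter_dict)
print(markdown_text)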
+from pathlib import Path + +from bertrend import MODELS_DIR, FEED_BASE_PATH, CONFIG_PATH + +# Config path for users +CONFIG_FEEDS_BASE_PATH = CONFIG_PATH / "users" +CONFIG_FEEDS_BASE_PATH.mkdir(parents=True, exist_ok=True) + +# Models config path +BASE_MODELS_DIR = MODELS_DIR / "users" +INTERPRETATION_PATH = "interpretation" + +# some identifiers +NOISE = "noise" +WEAK_SIGNALS = "weak_signals" +STRONG_SIGNALS = "strong_signals" +LLM_TOPIC_DESCRIPTION_COLUMN = "LLM Description" + +# Models & analysis +DEFAULT_GRANULARITY = 2 +DEFAULT_WINDOW_SIZE = 7 + +DEFAULT_ANALYSIS_CFG = { + "model_config": { + "granularity": DEFAULT_GRANULARITY, + "window_size": DEFAULT_WINDOW_SIZE, + "language": "French", + }, + "analysis_config": { + "topic_evolution": True, + "evolution_scenarios": True, + "multifactorial_analysis": True, + }, +} + + +def get_user_feed_path(user_name: str, feed_id: str) -> Path: + feed_path = CONFIG_FEEDS_BASE_PATH / user_name / f"{feed_id}_feed.toml" + return feed_path + + +def get_user_models_path(user_name: str, model_id: str) -> Path: + # Path to previously saved models for those data and this user + models_path = BASE_MODELS_DIR / user_name / model_id + models_path.mkdir(parents=True, exist_ok=True) + return models_path + + +def get_model_cfg_path(user_name: str, model_id: str) -> Path: + model_cfg_path = CONFIG_FEEDS_BASE_PATH / user_name / f"{model_id}_analysis.toml" + return model_cfg_path diff --git a/bertrend_apps/prospective_demo/app.py b/bertrend_apps/prospective_demo/app.py new file mode 100644 index 0000000..0902db5 --- /dev/null +++ b/bertrend_apps/prospective_demo/app.py @@ -0,0 +1,106 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
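# Quick sketch of where the helpers defined in prospective_demo/__init__.py above place
# a user's files; the user name "nemo" and watch id "energie" are illustrative:
#
#   get_user_feed_path("nemo", "energie")   -> <CONFIG_PATH>/users/nemo/energie_feed.toml
#   get_model_cfg_path("nemo", "energie")   -> <CONFIG_PATH>/users/nemo/energie_analysis.toml
#   get_user_models_path("nemo", "energie") -> <MODELS_DIR>/users/nemo/energie  (created on demand)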
+from typing import Literal + +import streamlit as st + +from bertrend.demos.demos_utils import is_admin_mode +from bertrend.demos.demos_utils.icons import ( + SETTINGS_ICON, + ANALYSIS_ICON, + NEWSLETTER_ICON, + SERVER_STORAGE_ICON, + TOPIC_ICON, + TREND_ICON, + MODELS_ICON, +) +from bertrend.demos.demos_utils.parameters_component import ( + display_bertopic_hyperparameters, + display_bertrend_hyperparameters, +) +from bertrend.demos.demos_utils.state_utils import SessionStateManager +from bertrend_apps.prospective_demo.authentication import check_password +from bertrend_apps.prospective_demo.dashboard_analysis import dashboard_analysis +from bertrend_apps.prospective_demo.feeds_config import configure_information_sources +from bertrend_apps.prospective_demo.feeds_data import display_data_status +from bertrend_apps.prospective_demo.models_info import models_monitoring +from bertrend_apps.prospective_demo.report_generation import reporting +from bertrend_apps.prospective_demo.signal_analysis import signal_analysis + +# UI Settings +# PAGE_TITLE = "BERTrend - Prospective Analysis demo" +PAGE_TITLE = "BERTrend - Démo Veille & Analyse" +LAYOUT: Literal["centered", "wide"] = "wide" + +# TODO: reactivate password +# AUTHENTIFICATION = True +AUTHENTIFICATION = False + + +def main(): + """Main page""" + st.set_page_config( + page_title=PAGE_TITLE, + layout=LAYOUT, + initial_sidebar_state="expanded" if is_admin_mode() else "collapsed", + page_icon=":part_alternation_mark:", + ) + + st.title(":part_alternation_mark: " + PAGE_TITLE) + + if AUTHENTIFICATION: + username = check_password() + if not username: + st.stop() + else: + SessionStateManager.set("username", username) + else: + SessionStateManager.get_or_set( + "username", "nemo" + ) # if username is not set or authentication deactivated + + # Sidebar + with st.sidebar: + st.header(SETTINGS_ICON + " Settings and Controls") + + # Main content + tab1, tab2, tab3, tab4, tab5 = st.tabs( + [ + NEWSLETTER_ICON + " Veilles", + MODELS_ICON + " Modèles", + TREND_ICON + " Tendances", + ANALYSIS_ICON + " Analyses", + NEWSLETTER_ICON + " Génération de rapports", + ] + ) + + with tab1: + with st.expander( + "Configuration des flux de données", expanded=True, icon=SETTINGS_ICON + ): + configure_information_sources() + + with st.expander( + "Etat de collecte des données", expanded=False, icon=SERVER_STORAGE_ICON + ): + display_data_status() + with tab2: + with st.expander( + "Statut des modèles par veille", expanded=True, icon=MODELS_ICON + ): + models_monitoring() + + with tab3: + signal_analysis() + + with tab4: + dashboard_analysis() + + with tab5: + reporting() + + +if __name__ == "__main__": + main() diff --git a/bertrend_apps/prospective_demo/authentication.py b/bertrend_apps/prospective_demo/authentication.py new file mode 100644 index 0000000..94e3967 --- /dev/null +++ b/bertrend_apps/prospective_demo/authentication.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
+import hmac + +import streamlit as st + +from bertrend.demos.demos_utils.icons import UNHAPPY_ICON + + +def login_form(): + """Form with widgets to collect user information""" + with st.form("Credentials"): + st.text_input("Username", key="username") + st.text_input("Password", type="password", key="password") + st.form_submit_button("Log in", on_click=password_entered) + + +def password_entered(): + """Checks whether a password entered by the user is correct.""" + if st.session_state["username"] in st.secrets["passwords"] and hmac.compare_digest( + st.session_state["password"], + st.secrets.passwords[st.session_state["username"]], + ): + st.session_state["password_correct"] = True + del st.session_state["password"] # Don't store the username or password. + # del st.session_state["username"] + else: + st.session_state["password_correct"] = False + + +def check_password() -> str | None: + """Returns the user name if the user had a correct password, otherwise None.""" + + # Return True if the username + password is validated. + if st.session_state.get("password_correct", False): + return st.session_state["username"] + + # Show inputs for username + password. + login_form() + if "password_correct" in st.session_state: + st.error(f"{UNHAPPY_ICON} User not known or password incorrect") + return None diff --git a/bertrend_apps/prospective_demo/dashboard_analysis.py b/bertrend_apps/prospective_demo/dashboard_analysis.py new file mode 100644 index 0000000..30c5b4f --- /dev/null +++ b/bertrend_apps/prospective_demo/dashboard_analysis.py @@ -0,0 +1,161 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +from pathlib import Path + +import pandas as pd +import streamlit as st + +from bertrend.demos.demos_utils.icons import WARNING_ICON +from bertrend.demos.weak_signals.visualizations_utils import ( + display_signal_categories_df, +) +from bertrend_apps.prospective_demo import ( + INTERPRETATION_PATH, + get_user_models_path, + WEAK_SIGNALS, + STRONG_SIGNALS, + NOISE, + LLM_TOPIC_DESCRIPTION_COLUMN, +) +from bertrend_apps.prospective_demo.models_info import get_models_info + +COLS_RATIO = [2 / 7, 5 / 7] + + +@st.fragment() +def dashboard_analysis(): + """Dashboard to analyze information monitoring results""" + st.session_state.signal_interpretations = {} + + col1, col2 = st.columns(COLS_RATIO) + with col1: + model_id = st.selectbox( + "Sélection de la veille", options=sorted(st.session_state.user_feeds.keys()) + ) + with col2: + list_models = get_models_info(model_id) + if not list_models: + st.warning(f"{WARNING_ICON} Pas de modèle disponible") + st.stop() + elif len(list_models) < 2: + st.warning( + f"{WARNING_ICON} 2 modèles minimum pour analyser les tendances !" 
+ ) + st.stop() + reference_ts = st.select_slider( + "Date d'analyse", + options=list_models, + value=list_models[-1], + format_func=lambda ts: ts.strftime("%d/%m/%Y"), + help="Sélection de la date d'analyse parmi celles disponibles", + ) + + # LLM-based interpretation + model_interpretation_path = ( + get_user_models_path(user_name=st.session_state.username, model_id=model_id) + / INTERPRETATION_PATH + / reference_ts.strftime("%Y-%m-%d") + ) + + dfs_topics = {} + for df_id in [NOISE, WEAK_SIGNALS, STRONG_SIGNALS]: + df_path = model_interpretation_path / f"{df_id}.parquet" + dfs_topics[df_id] = ( + pd.read_parquet(df_path) if df_path.exists() else pd.DataFrame() + ) + + cols = st.columns(COLS_RATIO) + with cols[0]: + # Display data frames + columns = [ + "Topic", + LLM_TOPIC_DESCRIPTION_COLUMN, + "Representation", + "Latest_Popularity", + "Docs_Count", + "Paragraphs_Count", + "Latest_Timestamp", + "Documents", + "Sources", + "Source_Diversity", + ] + + display_signal_categories_df( + dfs_topics[NOISE], + dfs_topics[WEAK_SIGNALS], + dfs_topics[STRONG_SIGNALS], + reference_ts, + columns=columns, + ) + + with cols[1]: + # Detailed analysis + st.subheader("Analyse détaillée par sujet") + display_detailed_analysis(model_id, model_interpretation_path, dfs_topics) + + +@st.fragment() +def display_detailed_analysis( + model_id: str, model_interpretation_path: Path, dfs_topics: dict[str, pd.DataFrame] +): + # Retrieve previously computed interpretation + interpretations = {} + for df_id, df in dfs_topics.items(): + if not df.empty: + interpretation_file_path = ( + model_interpretation_path / f"{df_id}_interpretation.jsonl" + ) + interpretations[df_id] = ( + pd.merge( + pd.read_json(interpretation_file_path, lines=True), + df, + how="left", + left_on="topic", + right_on="Topic", + ) + if interpretation_file_path.exists() + else {} + ) + + signal_topics = {WEAK_SIGNALS: [], STRONG_SIGNALS: []} + if WEAK_SIGNALS in interpretations: + signal_topics[WEAK_SIGNALS] = list(interpretations[WEAK_SIGNALS]["topic"]) + if STRONG_SIGNALS in interpretations: + signal_topics[STRONG_SIGNALS] = list(interpretations[STRONG_SIGNALS]["topic"]) + signal_list = signal_topics[WEAK_SIGNALS] + signal_topics[STRONG_SIGNALS] + selected_signal = st.selectbox( + label="Sélection du sujet", + label_visibility="hidden", + options=signal_list, + format_func=lambda signal_id: f"[Sujet {'émergent' if signal_id in signal_topics[WEAK_SIGNALS] else 'fort'} " + f"{signal_id}]: {get_row(signal_id, interpretations[WEAK_SIGNALS] if signal_id in signal_topics[WEAK_SIGNALS] else interpretations[STRONG_SIGNALS])[LLM_TOPIC_DESCRIPTION_COLUMN]['title']}", + ) + # Summary of the topic + desc = get_row( + selected_signal, + ( + interpretations[WEAK_SIGNALS] + if selected_signal in signal_topics[WEAK_SIGNALS] + else interpretations[STRONG_SIGNALS] + ), + ) + if selected_signal in list(signal_topics[WEAK_SIGNALS]): + color = "orange" + else: + color = "green" + st.subheader(f":{color}[**{desc[LLM_TOPIC_DESCRIPTION_COLUMN]['title']}**]") + st.write(desc[LLM_TOPIC_DESCRIPTION_COLUMN]["description"]) + # Detailed description + st.html(desc["analysis"]) + + st.session_state.signal_interpretations[model_id] = interpretations + + +def get_row(signal_id: int, df: pd.DataFrame) -> str: + filtered_df = df[df["topic"] == signal_id] + if not filtered_df.empty: + return filtered_df.iloc[0] # Return the Series (row) + else: + st.warning(f"No data found for signal ID: {signal_id}") diff --git a/bertrend_apps/prospective_demo/feeds_common.py 
b/bertrend_apps/prospective_demo/feeds_common.py new file mode 100644 index 0000000..f4f8b28 --- /dev/null +++ b/bertrend_apps/prospective_demo/feeds_common.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +from pathlib import Path + +from loguru import logger + +from bertrend import FEED_BASE_PATH, load_toml_config +from bertrend_apps.prospective_demo import CONFIG_FEEDS_BASE_PATH + + +def read_user_feeds(username: str) -> tuple[dict[str, dict], dict[str, Path]]: + """Read user feed config files""" + user_feed_dir = CONFIG_FEEDS_BASE_PATH / username + user_feed_dir.mkdir(parents=True, exist_ok=True) + logger.debug(f"Reading user feeds from: {user_feed_dir}") + matching_files = user_feed_dir.rglob("*_feed.toml") + + user_feeds = {} + feed_files = {} + for f in matching_files: + feed_id = f.name.split("_feed.toml")[0] + user_feeds[feed_id] = load_toml_config(f) + feed_files[feed_id] = f + + return user_feeds, feed_files + + +def get_all_files_for_feed(user_feeds: dict[str, dict], feed_id: str) -> list[Path]: + """Returns the paths of all files associated to a feed for the current user.""" + feed_base_dir = user_feeds[feed_id]["data-feed"]["feed_dir_path"] + list_all_files = list( + Path(FEED_BASE_PATH, feed_base_dir).glob( + f"*{user_feeds[feed_id]['data-feed'].get('id')}*.jsonl*" + ) + ) + return list_all_files diff --git a/bertrend_apps/prospective_demo/feeds_config.py b/bertrend_apps/prospective_demo/feeds_config.py new file mode 100644 index 0000000..73b3d13 --- /dev/null +++ b/bertrend_apps/prospective_demo/feeds_config.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
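# Illustrative return value of read_user_feeds() from feeds_common.py above, for a user
# owning a single feed named "energie" (paths and values are made up):
#
#   user_feeds == {"energie": {"data-feed": {"id": "feed_energie", "provider": "google", ...}}}
#   feed_files == {"energie": Path("<CONFIG_PATH>/users/nemo/energie_feed.toml")}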
+import re +import time +from pathlib import Path + +import pandas as pd +import streamlit as st +import toml +from loguru import logger + +from bertrend.config.parameters import LANGUAGES +from bertrend.demos.demos_utils.icons import ( + INFO_ICON, + ERROR_ICON, + ADD_ICON, + EDIT_ICON, + DELETE_ICON, + WARNING_ICON, + TOGGLE_ON_ICON, + TOGGLE_OFF_ICON, +) +from bertrend_apps.common.crontab_utils import ( + get_understandable_cron_description, + check_if_scrapping_active_for_user, + remove_scrapping_for_user, + schedule_scrapping, +) +from bertrend_apps.data_provider import URL_PATTERN +from bertrend_apps.prospective_demo.feeds_common import ( + read_user_feeds, +) +from bertrend_apps.prospective_demo import CONFIG_FEEDS_BASE_PATH +from bertrend_apps.prospective_demo.models_info import ( + remove_scheduled_training_for_user, +) +from bertrend_apps.prospective_demo.streamlit_utils import clickable_df + +# Default feed configs +DEFAULT_GNEWS_CRONTAB_EXPRESSION = "1 0 * * 1" +DEFAULT_CUREBOT_CRONTAB_EXPRESSION = "42 0,6,12,18 * * *" # 4 times a day +DEFAULT_MAX_RESULTS = 20 +DEFAULT_NUMBER_OF_DAYS = 14 +FEED_SOURCES = ["google", "curebot"] +TRANSLATION = {"English": "Anglais", "French": "Français"} + + +@st.dialog("Configuration d'un nouveau flux de données") +def edit_feed_monitoring(config: dict | None = None): + """Create or update a feed monitoring configuration.""" + chosen_id = st.text_input( + "ID :red[*]", + help="Identifiant du flux de données", + value=None if not config else config["id"], + ) + + provider = st.segmented_control( + "Source", + selection_mode="single", + options=FEED_SOURCES, + default=FEED_SOURCES[0] if not config else config["provider"], + help="Sélection de la source de données", + ) + if provider == "google": + query = st.text_input( + "Requête :red[*]", + value="" if not config else config["query"], + help="Saisir ici la requête qui sera faite sur Google News", + ) + language = st.segmented_control( + "Langue", + selection_mode="single", + options=LANGUAGES, + default=LANGUAGES[0], + format_func=lambda lang: TRANSLATION[lang], + help="Choix de la langue", + ) + if "update_frequency" not in st.session_state: + st.session_state.update_frequency = ( + DEFAULT_GNEWS_CRONTAB_EXPRESSION + if not config + else config["update_frequency"] + ) + new_freq = st.text_input( + f"Fréquence d'exécution", + value=st.session_state.update_frequency, + help=f"Fréquence de collecte des données", + ) + st.session_state.update_frequency = new_freq + st.write(display_crontab_description(st.session_state.update_frequency)) + + elif provider == "curebot": + query = st.text_input( + "ATOM feed :red[*]", + value="" if not config else config["query"], + help="URL du flux de données Curebot", + ) + + try: + get_understandable_cron_description(st.session_state.update_frequency) + valid_cron = True + except: + valid_cron = False + + if st.button( + "OK", + disabled=not chosen_id + or not query + or (query and provider == "curebot" and not re.match(URL_PATTERN, query)), + ): + if not config: + config = {} + config["id"] = "feed_" + chosen_id + config["feed_dir_path"] = ( + "users/" + st.session_state.username + "/feed_" + chosen_id + ) + config["query"] = query + config["provider"] = provider + if not config.get("max_results"): + config["max_results"] = DEFAULT_MAX_RESULTS + if not config.get("number_of_days"): + config["number_of_days"] = DEFAULT_NUMBER_OF_DAYS + if provider == "google": + config["language"] = "fr" if language == "French" else "en" + config["update_frequency"] = ( + 
st.session_state.update_frequency + if valid_cron + else DEFAULT_GNEWS_CRONTAB_EXPRESSION + ) + elif provider == "curebot": + config["language"] = "fr" + config["update_frequency"] = DEFAULT_CUREBOT_CRONTAB_EXPRESSION + + if "update_frequency" in st.session_state: + del st.session_state["update_frequency"] # to avoid memory effect + + # Remove prevous crontab if any + remove_scrapping_for_user(feed_id=chosen_id, user=st.session_state.username) + + # Save feed config and update crontab + save_feed_config(chosen_id, config) + + +def save_feed_config(chosen_id, feed_config: dict): + """Save the feed configuration to disk as a TOML file.""" + feed_path = ( + CONFIG_FEEDS_BASE_PATH / st.session_state.username / f"{chosen_id}_feed.toml" + ) + # Save the dictionary to a TOML file + with open(feed_path, "w") as toml_file: + toml.dump({"data-feed": feed_config}, toml_file) + logger.debug(f"Saved feed config {feed_config} to {feed_path}") + schedule_scrapping(feed_path, user=st.session_state.username) + st.rerun() + + +def display_crontab_description(crontab_expr: str) -> str: + try: + return f":blue[{INFO_ICON} {get_understandable_cron_description(crontab_expr)}]" + except Exception: + return f":red[{ERROR_ICON} Expression mal écrite !]" + + +def configure_information_sources(): + """Configure Information Sources.""" + # if "user_feeds" not in st.session_state: + st.session_state.user_feeds, st.session_state.feed_files = read_user_feeds( + st.session_state.username + ) + + displayed_list = [] + for k, v in st.session_state.user_feeds.items(): + displayed_list.append( + { + "id": k, + "provider": v["data-feed"]["provider"], + "query": v["data-feed"]["query"], + "language": v["data-feed"]["language"], + "update_frequency": v["data-feed"]["update_frequency"], + } + ) + df = pd.DataFrame(displayed_list) + if not df.empty: + df = df.sort_values(by="id", inplace=False).reset_index(drop=True) + + if st.button(f":green[{ADD_ICON}]", type="tertiary", help="Nouveau flux de veille"): + edit_feed_monitoring() + + clickable_df_buttons = [ + (EDIT_ICON, edit_feed_monitoring, "secondary"), + (lambda x: toggle_icon(df, x), toggle_feed, "secondary"), + (DELETE_ICON, handle_delete, "primary"), + ] + clickable_df(df, clickable_df_buttons) + + +def toggle_icon(df: pd.DataFrame, index: int) -> str: + """Switch the toggle icon depending on the statis of the scrapping feed in the crontab""" + feed_id = df["id"][index] + return ( + f":green[{TOGGLE_ON_ICON}]" + if check_if_scrapping_active_for_user( + feed_id=feed_id, user=st.session_state.username + ) + else f":red[{TOGGLE_OFF_ICON}]" + ) + + +def toggle_feed(cfg: dict): + """Activate / deactivate the feed from the crontab""" + feed_id = cfg["id"] + if check_if_scrapping_active_for_user( + feed_id=feed_id, user=st.session_state.username + ): + if remove_scrapping_for_user(feed_id=feed_id, user=st.session_state.username): + st.toast(f"Le flux **{feed_id}** est déactivé !", icon=INFO_ICON) + logger.info(f"Flux {feed_id} désactivé !") + else: + schedule_scrapping( + st.session_state.feed_files[feed_id], user=st.session_state.username + ) + st.toast(f"Le flux **{feed_id}** est activé !", icon=WARNING_ICON) + logger.info(f"Flux {feed_id} activé !") + time.sleep(0.2) + st.rerun() + + +def delete_feed_config(feed_id: str): + # remove config file + file_path: Path = st.session_state.feed_files[feed_id] + try: + file_path.unlink() + logger.debug(f"Feed file {file_path} has been removed.") + except Exception as e: + logger.error(f"An error occurred: {e}") + + 
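# Illustrative sketch of the feed configuration that edit_feed_monitoring() builds and
# save_feed_config() dumps to CONFIG_FEEDS_BASE_PATH/<username>/<id>_feed.toml under a
# single [data-feed] section. All values below are made up; the defaults reference the
# module-level constants defined at the top of this file.
example_feed_config = {
    "id": "feed_energie",
    "feed_dir_path": "users/nemo/feed_energie",
    "query": "sobriete energetique",
    "provider": "google",
    "language": "fr",
    "max_results": DEFAULT_MAX_RESULTS,  # 20
    "number_of_days": DEFAULT_NUMBER_OF_DAYS,  # 14
    "update_frequency": DEFAULT_GNEWS_CRONTAB_EXPRESSION,  # "1 0 * * 1"
}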
+@st.dialog("Confirmation") +def handle_delete(row_dict: dict): + """Function to handle remove click events""" + feed_id = row_dict["id"] + st.write( + f":orange[{WARNING_ICON}] Voulez-vous vraiment supprimer le flux de veille **{feed_id}** ?" + ) + col1, col2, _ = st.columns([2, 2, 8]) + with col1: + if st.button("Oui", type="primary"): + remove_scrapping_for_user(feed_id=feed_id, user=st.session_state.username) + delete_feed_config(feed_id) + logger.info(f"Flux {feed_id} supprimé !") + # Remove from crontab associated training + remove_scheduled_training_for_user( + model_id=feed_id, user=st.session_state.username + ) + time.sleep(0.2) + st.rerun() + with col2: + if st.button("Non"): + st.rerun() diff --git a/bertrend_apps/prospective_demo/feeds_data.py b/bertrend_apps/prospective_demo/feeds_data.py new file mode 100644 index 0000000..4104953 --- /dev/null +++ b/bertrend_apps/prospective_demo/feeds_data.py @@ -0,0 +1,95 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +import datetime +from pathlib import Path + +import pandas as pd +import streamlit as st + +from bertrend import FEED_BASE_PATH +from bertrend.utils.data_loading import ( + load_data, + TIMESTAMP_COLUMN, + TITLE_COLUMN, + URL_COLUMN, + TEXT_COLUMN, +) +from bertrend_apps.prospective_demo.feeds_common import get_all_files_for_feed + + +def display_data_status(): + if not st.session_state.user_feeds: + return + + col1, col2 = st.columns(2) + with col1: + st.selectbox( + "Sélection de la veille", + options=sorted(st.session_state.user_feeds.keys()), + key="id_data", + ) + + with col2: + if "data_time_window" not in st.session_state: + st.session_state.data_time_window = 7 + st.slider( + "Fenêtre temporelle (jours)", + min_value=1, + max_value=60, + step=1, + key="data_time_window", + ) + + display_data_info_for_feed(st.session_state.id_data) + + +def display_data_info_for_feed(feed_id: str): + all_files = get_all_files_for_feed(st.session_state.user_feeds, feed_id) + df = get_all_data(files=all_files) + + if df.empty: + df_filtered = pd.DataFrame() + else: + df = df[ + [TITLE_COLUMN, URL_COLUMN, TEXT_COLUMN, TIMESTAMP_COLUMN] + ] # filter useful columns + + cutoff_date = datetime.datetime.now() - datetime.timedelta( + days=st.session_state.data_time_window + ) + df_filtered = df[df[TIMESTAMP_COLUMN] >= cutoff_date] + + stats = { + "ID": feed_id, + "# Fichiers": len(all_files), + "Date début": df[TIMESTAMP_COLUMN].min() if not df.empty else None, + "Date fin": df[TIMESTAMP_COLUMN].max() if not df.empty else None, + "# Articles": len(df), + f"# Articles ({st.session_state.data_time_window} derniers jours)": len( + df_filtered + ), + } + + st.dataframe(pd.DataFrame([stats])) + + st.write(f"#### Données des derniers {st.session_state.data_time_window} jours") + st.dataframe( + df_filtered, + use_container_width=True, + hide_index=True, + column_config={"url": st.column_config.LinkColumn("url")}, + ) + + +@st.cache_data +def get_all_data(files: list[Path]) -> pd.DataFrame: + """Returns the data contained in the provided files as a single DataFrame.""" + if not files: + return pd.DataFrame() + dfs = [load_data(Path(f)) for f in files] + new_df = pd.concat(dfs).drop_duplicates( + subset=["title"], keep="first", inplace=False + ) + return new_df diff --git a/bertrend_apps/prospective_demo/llm_utils.py b/bertrend_apps/prospective_demo/llm_utils.py new file mode 100644 index 0000000..18f383d --- /dev/null +++ 
b/bertrend_apps/prospective_demo/llm_utils.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
+import json
+
+from loguru import logger
+
+from bertrend import LLM_CONFIG
+from bertrend.llm_utils.openai_client import OpenAI_Client
+from bertrend.topic_analysis.prompts import TOPIC_DESCRIPTION_PROMPT
+
+
+def generate_bertrend_topic_description(
+    topic_words: str,
+    topic_number: int,
+    texts: list[str],
+    language_code: str = "fr",
+) -> dict:
+    """Generates an LLM-based, human-readable description of a topic, composed of a title and a description (returned as a dict)."""
+    if not texts:
+        logger.warning(f"No text found for topic number {topic_number}")
+        return {"title": "", "description": ""}
+
+    topic_representation = ", ".join(topic_words.split("_"))  # Turn the underscore-joined topic words into a comma-separated list
+
+    # Prepare the documents text
+    docs_text = "\n\n".join(
+        [f"Document {i + 1}: {doc[0:2000]}..." for i, doc in enumerate(texts)]
+    )
+
+    # Prepare the prompt
+    prompt = TOPIC_DESCRIPTION_PROMPT[language_code]
+    try:
+        client = OpenAI_Client(
+            api_key=LLM_CONFIG["api_key"],
+            endpoint=LLM_CONFIG["endpoint"],
+            model=LLM_CONFIG["model"],
+        )
+        answer = client.generate(
+            response_format={"type": "json_object"},
+            user_prompt=prompt.format(
+                topic_representation=topic_representation,
+                docs_text=docs_text,
+            ),
+        )
+        return json.loads(answer)
+    except Exception as e:
+        logger.error(f"Error calling OpenAI API: {e}")
+        return {"title": "Error", "description": f"Error generating description: {e}"}
diff --git a/bertrend_apps/prospective_demo/models_info.py b/bertrend_apps/prospective_demo/models_info.py
new file mode 100644
index 0000000..fac1c6a
--- /dev/null
+++ b/bertrend_apps/prospective_demo/models_info.py
@@ -0,0 +1,303 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
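# Illustrative use of generate_bertrend_topic_description() from llm_utils.py above.
# topic_words is the underscore-joined keyword string that the helper splits on "_";
# every value below is made up.
#
# desc = generate_bertrend_topic_description(
#     topic_words="sobriete_energie_consommation_chauffage",
#     topic_number=3,
#     texts=["Premier article du sujet...", "Deuxieme article du sujet..."],
#     language_code="fr",
# )
# desc["title"], desc["description"]  # dict produced by the LLM ({"title": "", "description": ""} if no texts)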
+import random +import shutil +import sys +import time + +import pandas as pd +import streamlit as st +import toml +from loguru import logger + +from bertrend import BERTREND_LOG_PATH, BEST_CUDA_DEVICE, load_toml_config +from bertrend.demos.demos_utils.icons import ( + EDIT_ICON, + DELETE_ICON, + WARNING_ICON, + INFO_ICON, + TOGGLE_ON_ICON, + TOGGLE_OFF_ICON, +) +from bertrend_apps.common.crontab_utils import ( + check_cron_job, + remove_from_crontab, + add_job_to_crontab, +) +from bertrend_apps.prospective_demo import ( + get_user_models_path, + get_model_cfg_path, + DEFAULT_ANALYSIS_CFG, +) +from bertrend_apps.prospective_demo.streamlit_utils import clickable_df + + +@st.fragment +def models_monitoring(): + if not st.session_state.user_feeds: + st.stop() + + st.session_state.model_analysis_cfg = {} + displayed_list = [] + + for model_id in sorted(st.session_state.user_feeds.keys()): + try: + st.session_state.model_analysis_cfg[model_id] = load_toml_config( + get_model_cfg_path( + user_name=st.session_state.username, model_id=model_id + ) + ) + except Exception: + st.session_state.model_analysis_cfg[model_id] = DEFAULT_ANALYSIS_CFG + list_models = get_models_info(model_id) + displayed_list.append( + { + "id": model_id, + "# modèles": len(list_models) if list_models else 0, + "date 1er modèle": list_models[0] if list_models else None, + "date dernier modèle": list_models[-1] if list_models else None, + "fréquence mise à jour (# jours)": st.session_state.model_analysis_cfg[ + model_id + ]["model_config"]["granularity"], + "fenêtre d'analyse (# jours)": st.session_state.model_analysis_cfg[ + model_id + ]["model_config"]["window_size"], + } + ) + + st.session_state.models_paths = { + model_id: get_user_models_path(st.session_state.username, model_id) + for model_id in st.session_state.user_feeds.keys() + } + + df = pd.DataFrame(displayed_list) + if not df.empty: + df = df.sort_values(by="id", inplace=False).reset_index(drop=True) + + clickable_df_buttons = [ + (EDIT_ICON, edit_model_parameters, "secondary"), + (lambda x: toggle_icon(df, x), toggle_learning, "secondary"), + (DELETE_ICON, handle_delete_models, "primary"), + ] + clickable_df(df, clickable_df_buttons) + + +@st.dialog("Paramètres") +def edit_model_parameters(row_dict: dict): + model_id = row_dict["id"] + st.write(f"**Paramètres des modèles pour la veille {model_id}**") + + new_granularity = st.slider( + "Fréquence de mise à jour des modèles (en jours)", + min_value=1, + max_value=30, + value=st.session_state.model_analysis_cfg[model_id]["model_config"][ + "granularity" + ], + step=1, + help=f"{INFO_ICON} Sélection de la fréquence à laquelle la détection de sujets est effectuée. 
" + f"Le nombre de jours sélectionné doit être choisi pour s'assurer d'un volume de données suffisant.", + ) + new_window_size = st.slider( + "Sélection de la fenêtre temporelle (en jours)", + min_value=new_granularity, + max_value=30, + value=max( + st.session_state.model_analysis_cfg[model_id]["model_config"][ + "window_size" + ], + new_granularity, + ), + step=1, + help=f"{INFO_ICON} Sélection de la plage temporelle considérée pour calculer les différents " + f"types de signaux (faibles, forts)", + ) + + st.write(f"**Paramètres d'analyse de la veille {model_id}: éléments à inclure**") + topic_evolution = st.checkbox( + "Evolution du sujet", + value=st.session_state.model_analysis_cfg[model_id]["analysis_config"][ + "topic_evolution" + ], + ) + evolution_scenarios = st.checkbox( + "Scénarios d'évolution", + value=st.session_state.model_analysis_cfg[model_id]["analysis_config"][ + "evolution_scenarios" + ], + ) + multifactorial_analysis = st.checkbox( + "Analyse multifactorielle", + value=st.session_state.model_analysis_cfg[model_id]["analysis_config"][ + "multifactorial_analysis" + ], + ) + + model_config = { + "granularity": new_granularity, + "window_size": new_window_size, + "language": ( + "French" + if st.session_state.user_feeds[model_id]["data-feed"]["language"] == "fr" + else "English" + ), + } + analysis_config = { + "topic_evolution": topic_evolution, + "evolution_scenarios": evolution_scenarios, + "multifactorial_analysis": multifactorial_analysis, + } + + if st.button("OK"): + save_model_config( + model_id, {"model_config": model_config, "analysis_config": analysis_config} + ) + st.rerun() + + +def save_model_config(model_id: str, config: dict): + model_cfg_path = get_model_cfg_path( + user_name=st.session_state.username, model_id=model_id + ) + with open(model_cfg_path, "w") as toml_file: + toml.dump(config, toml_file) + logger.debug(f"Saved model analysis config {config} to {model_cfg_path}") + + +def load_model_config(model_id: str) -> dict: + model_cfg_path = get_model_cfg_path( + user_name=st.session_state.username, model_id=model_id + ) + try: + return load_toml_config(model_cfg_path) + except Exception: + return DEFAULT_ANALYSIS_CFG + + +@st.dialog("Confirmation") +def handle_delete_models(row_dict: dict): + """Function to handle remove models from cache""" + model_id = row_dict["id"] + st.write( + f":orange[{WARNING_ICON}] Voulez-vous vraiment supprimer tous les modèles stockés pour la veille **{model_id}** ?" 
+ ) + col1, col2, _ = st.columns([2, 2, 8]) + with col1: + if st.button("Oui", type="primary"): + remove_scheduled_training_for_user( + model_id=model_id, user=st.session_state.username + ) + delete_cached_models(model_id) + logger.info(f"Modèles en cache supprimés pour la veille {model_id} !") + time.sleep(0.2) + st.rerun() + with col2: + if st.button("Non"): + st.rerun() + + +def toggle_learning(cfg: dict): + """Activate / deactivate the learning from the crontab""" + model_id = cfg["id"] + if check_if_learning_active_for_user( + model_id=model_id, user=st.session_state.username + ): + if remove_scheduled_training_for_user( + model_id=model_id, user=st.session_state.username + ): + st.toast( + f"Le learning pour la veille **{model_id}** est déactivé !", + icon=INFO_ICON, + ) + logger.info(f"Learning pour {model_id} désactivé !") + else: + schedule_training_for_user(model_id, st.session_state.username) + st.toast( + f"Le learning pour la veille **{model_id}** est activé !", icon=WARNING_ICON + ) + logger.info(f"Learning pour {model_id} activé !") + time.sleep(0.2) + st.rerun() + + +def toggle_icon(df: pd.DataFrame, index: int) -> str: + """Switch the toggle icon depending on the statis of the scrapping feed in the crontab""" + model_id = df["id"][index] + return ( + f":green[{TOGGLE_ON_ICON}]" + if check_if_learning_active_for_user( + model_id=model_id, user=st.session_state.username + ) + else f":red[{TOGGLE_OFF_ICON}]" + ) + + +def check_if_learning_active_for_user(model_id: str, user: str): + """Checks if a given scrapping feed is active (registered in the crontab""" + if user: + return check_cron_job(rf"process_new_data.*{user}.*{model_id}.*>") + else: + return False + + +def remove_scheduled_training_for_user(model_id: str, user: str): + """Removes from the crontab the training job matching the provided model_id""" + if user: + return remove_from_crontab(rf"process_new_data.*{user}.*{model_id}.*>") + + +def schedule_training_for_user(model_id: str, user: str): + """Schedule data scrapping on the basis of a feed configuration file""" + schedule = generate_crontab_expression( + st.session_state.model_analysis_cfg[model_id]["model_config"]["granularity"] + ) + logpath = BERTREND_LOG_PATH / "users" / user + logpath.mkdir(parents=True, exist_ok=True) + command = ( + f"{sys.prefix}/bin/python -m bertrend_apps.prospective_demo.process_new_data {user} {model_id} " + f"> {logpath}/learning_{model_id}.log 2>&1" + ) + env_vars = f"CUDA_VISIBLE_DEVICES={BEST_CUDA_DEVICE}" + add_job_to_crontab(schedule, command, env_vars) + + +def delete_cached_models(model_id: str): + """Removes models from the cache""" + # Remove the directory and all its contents + shutil.rmtree(st.session_state.models_paths[model_id]) + + +def generate_crontab_expression(days_interval: int) -> str: + # Random hour between 0 and 6 (inclusive) + hour = random.randint(0, 6) # run during the night + # Random minute rounded to the nearest 10 + minute = random.choice([0, 10, 20, 30, 40, 50]) + # Compute days + days = [str(i) for i in range(1, 31, days_interval)] + # Crontab expression format: minute hour day_of_month month day_of_week + crontab_expression = f"{minute} {hour} {','.join(days)} * *" + return crontab_expression + + +def safe_timestamp(x: str) -> pd.Timestamp | None: + try: + return pd.Timestamp(x) + except Exception as e: + return None + + +def get_models_info(model_id: str) -> list: + """Returns the list of topic models that are stored, identified by their timestamp""" + user_model_dir = 
get_user_models_path(st.session_state.username, model_id) + if not user_model_dir.exists(): + return [] + matching_files = user_model_dir.glob(r"????-??-??") + return sorted( + [ + safe_timestamp(x.name) + for x in matching_files + if safe_timestamp(x.name) is not None + ] + ) diff --git a/bertrend_apps/prospective_demo/process_new_data.py b/bertrend_apps/prospective_demo/process_new_data.py new file mode 100644 index 0000000..5eae98a --- /dev/null +++ b/bertrend_apps/prospective_demo/process_new_data.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +from datetime import timedelta +from pathlib import Path + +import pandas as pd +import typer +from jsonlines import jsonlines +from loguru import logger + +from bertrend import load_toml_config, FEED_BASE_PATH +from bertrend.BERTrend import train_new_data, BERTrend +from bertrend.services.embedding_service import EmbeddingService +from bertrend.trend_analysis.weak_signals import analyze_signal +from bertrend.utils.data_loading import load_data, split_data +from bertrend_apps.prospective_demo import ( + get_user_feed_path, + get_user_models_path, + INTERPRETATION_PATH, + NOISE, + WEAK_SIGNALS, + STRONG_SIGNALS, + LLM_TOPIC_DESCRIPTION_COLUMN, + DEFAULT_ANALYSIS_CFG, + get_model_cfg_path, +) +from bertrend_apps.prospective_demo.llm_utils import generate_bertrend_topic_description + +DEFAULT_TOP_K = 5 + +if __name__ == "__main__": + app = typer.Typer() + + @app.command("train-new-model") + def train_new_model( + user_name: str = typer.Argument(help="Identifier of the user"), + model_id: str = typer.Argument(help="ID of the model/data to train"), + ): + # Load model & analysis config + model_cfg_path = get_model_cfg_path(user_name, model_id) + try: + model_analysis_cfg = load_toml_config(model_cfg_path) + except Exception: + model_analysis_cfg = DEFAULT_ANALYSIS_CFG + # Extract relevant values + granularity = model_analysis_cfg["model_config"]["granularity"] + window_size = model_analysis_cfg["model_config"]["window_size"] + language = model_analysis_cfg["model_config"]["language"] + if language not in ["French", "English"]: + language = "French" + language_code = "fr" if language == "French" else "en" + + # Path to previously saved models for those data and this user + bertrend_models_path = get_user_models_path(user_name, model_id) + + # Initialization of embedding service + # TODO: customize service (lang, etc) + embedding_service = EmbeddingService(local=True) + + # load data for last period + # TODO: to be improved + cfg_file = get_user_feed_path(user_name, model_id) + if not cfg_file.exists(): + logger.error(f"Cannot find/process config file: {cfg_file}") + return + cfg = load_toml_config(cfg_file) + feed_base_dir = cfg["data-feed"]["feed_dir_path"] + files = list( + Path(FEED_BASE_PATH, feed_base_dir).glob( + f"*{cfg['data-feed'].get('id')}*.jsonl*" + ) + ) + if not files: + logger.warning(f"No new data for '{model_id}', nothing to do") + return + + dfs = [load_data(Path(f), language=language) for f in files] + new_data = pd.concat(dfs).drop_duplicates( + subset=["title"], keep="first", inplace=False + ) + + # filter data according to granularity + # Calculate the date X days ago + reference_timestamp = pd.Timestamp( + new_data["timestamp"].max().date() + ) # used to identify the last model + cut_off_date = new_data["timestamp"].max() - timedelta(days=granularity) + # Filter the DataFrame to keep only the rows within the 
last X days + filtered_df = new_data[new_data["timestamp"] >= cut_off_date] + + filtered_df = split_data(filtered_df) + + logger.info(f'Processing new data for user "{user_name}" about "{model_id}"...') + # Process new data + bertrend = train_new_data( + filtered_df, + bertrend_models_path=bertrend_models_path, + embedding_service=embedding_service, + language=language, + granularity=granularity, + ) + + if not bertrend._are_models_merged: + # This is generally the case when we have only one model + return + + # Compute popularities + bertrend.calculate_signal_popularity() + + # classify last signals + noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = ( + bertrend.classify_signals(window_size, cut_off_date) + ) + + # LLM-based interpretation + interpretation_path = ( + bertrend_models_path + / INTERPRETATION_PATH + / reference_timestamp.strftime("%Y-%m-%d") + ) + interpretation_path.mkdir(parents=True, exist_ok=True) + for df, df_name in zip( + [noise_topics_df, weak_signal_topics_df, strong_signal_topics_df], + [NOISE, WEAK_SIGNALS, STRONG_SIGNALS], + ): + if not df.empty: + # enrich signal description with LLM-based topic description + df[LLM_TOPIC_DESCRIPTION_COLUMN] = df.apply( + lambda row: generate_bertrend_topic_description( + topic_words=row["Representation"], + topic_number=row["Topic"], + texts=row["Documents"], + language_code=language_code, + ), + axis=1, + ) + df.to_parquet(f"{interpretation_path}/{df_name}.parquet") + + # Obtain detailed LLM-based interpretion for signals + generate_llm_interpretation( + bertrend, + reference_timestamp=reference_timestamp, + df=df, + df_name=df_name, + output_path=interpretation_path, + ) + + def generate_llm_interpretation( + bertrend: BERTrend, + reference_timestamp: pd.Timestamp, + df: pd.DataFrame, + df_name: str, + output_path: Path, + top_k: int = DEFAULT_TOP_K, + ): + """ + Generate detailed analysis for the top k topics using parallel processing. + + Args: + bertrend: BERTrend instance + reference_timestamp: Reference timestamp for analysis + df: Input DataFrame + df_name: Name of the DataFrame for output + output_path: Path to save the results + top_k: Number of top topics to analyze + """ + + interpretation = [] + for topic in df.sort_values(by=["Latest_Popularity"], ascending=False).head( + top_k + )["Topic"]: + summary, analysis, formatted_html = analyze_signal( + bertrend, topic, reference_timestamp + ) + interpretation.append( + {"topic": topic, "summary": summary, "analysis": formatted_html} + ) + + # Save interpretation + output_file_name = output_path / f"{df_name}_interpretation.jsonl" + with jsonlines.open( + output_file_name, + mode="w", + ) as writer: + for item in interpretation: + writer.write(item) + logger.success(f"Interpretation saved to: {output_file_name}") + + # Main app + app() diff --git a/bertrend_apps/prospective_demo/report_generation.py b/bertrend_apps/prospective_demo/report_generation.py new file mode 100644 index 0000000..c910d92 --- /dev/null +++ b/bertrend_apps/prospective_demo/report_generation.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
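# The typer entry point defined in process_new_data.py above is normally run from cron
# (see schedule_training_for_user in models_info.py). An equivalent manual run for a
# hypothetical user "nemo" and feed "energie" would be:
#
#   python -m bertrend_apps.prospective_demo.process_new_data nemo energie
#
# The script loads the feed's recent *.jsonl files, keeps only the last `granularity`
# days, trains and merges a new topic model, then writes the noise/weak/strong signal
# parquet files and their LLM interpretations under the user's model directory
# (interpretation/<YYYY-MM-DD>/).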
+import pandas as pd +import streamlit as st + +from bertrend.demos.demos_utils.icons import NEWSLETTER_ICON, TOPIC_ICON +from bertrend_apps.prospective_demo import ( + WEAK_SIGNALS, + STRONG_SIGNALS, + LLM_TOPIC_DESCRIPTION_COLUMN, +) + +WEAK_SIGNAL_NB = 3 +STRONG_SIGNAL_NB = 5 + + +@st.fragment +def reporting(): + st.selectbox( + "Sélection de la veille", + options=sorted(st.session_state.user_feeds.keys()), + key="report_id", + ) + + tab1, tab2 = st.tabs( + [ + TOPIC_ICON + " Etape 1: Sélection des sujets à retenir", + NEWSLETTER_ICON + " Etape 2", + ] + ) + with tab1: + choose_topics() + + # generate_newsletter() + + +def choose_topics(): + st.subheader("Etape 1: Sélection des sujets à retenir") + model_id = st.session_state.report_id + cols = st.columns(2) + with cols[0]: + st.write("#### :orange[Sujets émergents]") + st.session_state.weak_topics_list = choose_from_df( + st.session_state.signal_interpretations[model_id][WEAK_SIGNALS] + ) + with cols[1]: + st.write("#### :green[Sujets forts]") + st.session_state.strong_topics_list = choose_from_df( + st.session_state.signal_interpretations[model_id][STRONG_SIGNALS] + ) + + +def choose_from_df(df: pd.DataFrame): + df["A retenir"] = True + df["Sujet"] = df[LLM_TOPIC_DESCRIPTION_COLUMN].apply(lambda r: r["title"]) + df["Description"] = df[LLM_TOPIC_DESCRIPTION_COLUMN].apply( + lambda r: r["description"] + ) + columns = ["Topic", "A retenir", "Sujet", "Description"] + pd.DataFrame( + [ + {"command": "st.selectbox", "rating": 4, "is_widget": True}, + {"command": "st.balloons", "rating": 5, "is_widget": False}, + {"command": "st.time_input", "rating": 3, "is_widget": True}, + ] + ) + edited_df = st.data_editor(df[columns], num_rows="dynamic", column_order=columns) + selection = edited_df[edited_df["A retenir"] == True]["Topic"].tolist() + return selection diff --git a/bertrend_apps/prospective_demo/signal_analysis.py b/bertrend_apps/prospective_demo/signal_analysis.py new file mode 100644 index 0000000..0f66425 --- /dev/null +++ b/bertrend_apps/prospective_demo/signal_analysis.py @@ -0,0 +1,12 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +import streamlit as st + + +@st.fragment +def signal_analysis(): + st.write( + "Ici mettre seulement les tableaux weak / strong + les liens vers les articles" + ) diff --git a/bertrend_apps/prospective_demo/start_demo.sh b/bertrend_apps/prospective_demo/start_demo.sh new file mode 100755 index 0000000..c150877 --- /dev/null +++ b/bertrend_apps/prospective_demo/start_demo.sh @@ -0,0 +1,9 @@ +# +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. +# + +# Starts the Trend Analysis application +CUDA_VISIBLE_DEVICES=0 streamlit run --theme.primaryColor royalblue app.py \ No newline at end of file diff --git a/bertrend_apps/prospective_demo/streamlit_utils.py b/bertrend_apps/prospective_demo/streamlit_utils.py new file mode 100644 index 0000000..78f037e --- /dev/null +++ b/bertrend_apps/prospective_demo/streamlit_utils.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
+import zlib +from typing import Callable + +import pandas as pd +import streamlit as st + + +def clickable_df( + df: pd.DataFrame, clickable_buttons: list[tuple[str | Callable, Callable, str]] +): + """Streamlit display of a df-like rendering with additional clickable columns (buttons).""" + if df is None or df.empty: + return + cols = st.columns(len(df.columns) * [3] + len(clickable_buttons) * [1]) + for i, c in enumerate(df.columns): + with cols[i]: + st.write(f"**{c}**") + # Generate a unique identifier, this will be used to identify the keys in case multiple clickable_df are used + unique_id = zlib.crc32(" ".join(df.columns.tolist()).encode()) + for index, row in df.iterrows(): + # Create a clickable container for each row + cols = st.columns(len(df.columns) * [3] + len(clickable_buttons) * [1]) + for i, col in enumerate(cols[: -len(clickable_buttons)]): + with col: + st.write(row[df.columns[i]]) + # Render the additional columns (clickable) + for i, button in enumerate(clickable_buttons): + with cols[len(df.columns) + i]: + button_label = button[0](index) if callable(button[0]) else button[0] + if st.button( + button_label, key=f"button{unique_id}_{i}_{index}", type=button[2] + ): + button[1](df.iloc[index].to_dict()) diff --git a/getting_started/bertrend_quickstart.ipynb b/getting_started/bertrend_quickstart.ipynb new file mode 100644 index 0000000..a372812 --- /dev/null +++ b/getting_started/bertrend_quickstart.ipynb @@ -0,0 +1,895 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "73870050e69c50e6", + "metadata": {}, + "source": [ + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "10a9d82c667c7fbe", + "metadata": {}, + "source": [ + "# BERTrend quickstart\n", + "The purpose of this notebook is to complement the existing demos available in the directory `bertrend/demos` with some code examples that explain how to integrate BERTrend with your application code." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "849734b0d71f2495", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:00:24.370757Z", + "start_time": "2025-01-20T15:00:24.349873Z" + } + }, + "outputs": [], + "source": [ + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "a795490c2d3e539e", + "metadata": {}, + "source": [ + "## BERTrend installation" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ba4a7eacde91b892", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T21:07:28.378082Z", + "start_time": "2025-01-26T21:07:28.370941Z" + } + }, + "outputs": [], + "source": [ + "import json\n", + "from pathlib import Path\n", + "import pandas as pd\n", + "from pandas import Timestamp\n", + "from IPython.display import display\n", + "from loguru import logger\n", + "\n", + "from bertrend import DATA_PATH\n", + "from bertrend.BERTrend import BERTrend\n", + "from bertrend import MODELS_DIR\n", + "from bertrend.utils.data_loading import load_data, split_data, TEXT_COLUMN\n", + "from bertrend.services.embedding_service import EmbeddingService\n", + "from bertrend.BERTopicModel import BERTopicModel\n", + "from bertrend.topic_analysis.topic_description import generate_topic_description\n", + "from bertrend.trend_analysis.weak_signals import analyze_signal\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "74702a2391f80f72", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T21:07:30.328141Z", + "start_time": "2025-01-26T21:07:30.324568Z" + } + }, + "outputs": [], + "source": [ + "#!pip install bertrend" + ] + }, + { + "cell_type": "markdown", + "id": "ca03bdd5398b56b3", + "metadata": {}, + "source": [ + "### Configuration of topic models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b97d93ac81a4d420", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:00:35.343828Z", + "start_time": "2025-01-20T15:00:35.298417Z" + } + }, + "outputs": [], + "source": [ + "# Topic model with default parameters - each parameter of BERTopic can be modified from the constructor or can be read from a configuration file\n", + "# overrides the default config to use English\n", + "config = '''\n", + "# Default configuration file to be used for topic model\n", + "\n", + "# Global parameters\n", + "[global]\n", + "language = \"English\"\n", + "\n", + "# BERTopic parameters: https://maartengr.github.io/BERTopic/api/bertopic.html#bertopic._bertopic.BERTopic.__init__\n", + "[bertopic_model]\n", + "top_n_words = 10\n", + "verbose = true\n", + "representation_model = [\"MaximalMarginalRelevance\"] # KeyBERTInspired, OpenAI\n", + "zeroshot_topic_list = []\n", + "zeroshot_min_similarity = 0\n", + "\n", + "# UMAP parameters: https://umap-learn.readthedocs.io/en/latest/api.html\n", + "[umap_model]\n", + "n_neighbors = 5\n", + "n_components = 5\n", + "min_dist = 0.0\n", + "metric = \"cosine\"\n", + "random_state = 42\n", + "\n", + "# HDBSCAN parameters: https://hdbscan.readthedocs.io/en/latest/api.html\n", + "[hdbscan_model]\n", + "min_cluster_size = 5\n", + "min_samples = 5\n", + "metric = \"euclidean\"\n", + "cluster_selection_method = \"eom\"\n", + "prediction_data = true\n", + "\n", + "# CountVectorizer: https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html\n", + "[vectorizer_model]\n", + "ngram_range = [1, 1]\n", + "stop_words = true # If true, will check `language` parameter and load associated stopwords 
file\n", + "min_df = 2\n", + "\n", + "# ClassTfidfTransformer: https://maartengr.github.io/BERTopic/api/ctfidf.html\n", + "[ctfidf_model]\n", + "bm25_weighting = false\n", + "reduce_frequent_words = true\n", + "\n", + "# MaximalMarginalRelevance: https://maartengr.github.io/BERTopic/api/representation/mmr.html\n", + "[mmr_model]\n", + "diversity = 0.3\n", + "\n", + "# Reduce outliers: https://maartengr.github.io/BERTopic/api/bertopic.html#bertopic._bertopic.BERTopic.reduce_outliers\n", + "[reduce_outliers]\n", + "strategy = \"c-tf-idf\"\n", + "'''\n", + "\n", + "topic_model = BERTopicModel(config)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa92f4b55e7b7b72", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:00:35.547370Z", + "start_time": "2025-01-20T15:00:35.486087Z" + } + }, + "outputs": [], + "source": [ + "# The TopicModel class is mainly a wrapper around BERTopic and can be used as-is, for example for a first analysis of data (without considering evolving trends, but this is not mandatory at all)\n" + ] + }, + { + "cell_type": "markdown", + "id": "7cfd832467877a23", + "metadata": {}, + "source": [ + "## Using BERTrend for retrospective analysis" + ] + }, + { + "cell_type": "markdown", + "id": "6a07ec11284b82cb", + "metadata": {}, + "source": [ + "### Instantiation of BERTrend\n" + ] + }, + { + "cell_type": "markdown", + "id": "c5118dce73f8cfce", + "metadata": {}, + "source": [ + "In the case of a **retrospective trend analysis** task, the goal is to identify and evaluate patterns or changes over time within a dataset, allowing for insights into historical performance, behaviors, or events that can inform future decision-making and strategy development.\n", + "\n", + "In this context, the general principle consists in splitting the past data into different time slices. Then each dataset is used to train a separate topic models. Each topic model description corresponding to the older data slice is merged with the next one and decay factors are applied. This allows to have a vision of topic evolution over time" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52bc66eed5bb040", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:00:35.784959Z", + "start_time": "2025-01-20T15:00:35.745153Z" + } + }, + "outputs": [], + "source": [ + "# Basic creation of the object and parametrization\n", + "# BERTrend uses several topic models; therefore, it is necessary to pass a topic_model object as a reference\n", + "bertrend = BERTrend(topic_model=topic_model)" + ] + }, + { + "cell_type": "markdown", + "id": "bf7cd6699bf77299", + "metadata": {}, + "source": [ + "### 1. 
Gather historical data to be analyzed\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "154fb553f7004986", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:00:35.978219Z", + "start_time": "2025-01-20T15:00:35.813108Z" + } + }, + "outputs": [], + "source": [ + "# Here some Trump tweets from: https://github.com/MarkHershey/CompleteTrumpTweetsArchive/blob/master/data/realDonaldTrump_in_office.csv\n", + "#!wget \"https://raw.githubusercontent.com/MarkHershey/CompleteTrumpTweetsArchive/refs/heads/master/data/realDonaldTrump_in_office.csv\"\n", + "df = pd.read_csv(\"realDonaldTrump_in_office.csv\", sep=',',quotechar='\"', skipinitialspace=True)\n", + "# BERTrend expects specific data format\n", + "df = df.rename(columns={'Time': 'timestamp', 'Tweet URL': 'url', \"Tweet Text\": \"text\"})\n", + "df[\"source\"]=df[\"ID\"]\n", + "df[\"document_id\"] = df.index\n", + "df.reset_index(inplace=True, drop=True)\n", + "df.head(5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e8b96b46718241", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:00:36.089939Z", + "start_time": "2025-01-20T15:00:36.031108Z" + } + }, + "outputs": [], + "source": [ + "df.index" + ] + }, + { + "cell_type": "markdown", + "id": "9d26753d9496a25", + "metadata": {}, + "source": [ + "### 2. Embed data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ca3e17198fdbb6a", + "metadata": {}, + "outputs": [], + "source": [ + "# Selection of a subset of data\n", + "df = df.head(1000)\n", + "\n", + "#embedding_service_cfg = {\"local\": False, \"host\":\"10.132.5.44\", \"port\": 6464}\n", + "\n", + "#embedding_service = EmbeddingService(**embedding_service_cfg)\n", + "embedding_service = EmbeddingService()\n", + "embeddings, token_strings, token_embeddings = embedding_service.embed(\n", + " texts=df[\"text\"],\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72df96f5c7d8d52b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:01:16.856529Z", + "start_time": "2025-01-20T15:01:16.812294Z" + } + }, + "outputs": [], + "source": [ + "embedding_model_name = embedding_service.embedding_model_name\n" + ] + }, + { + "cell_type": "markdown", + "id": "2e94b24d1ef107a2", + "metadata": {}, + "source": [ + "### 3. Split the data into time slices\n", + "\n", + "This can be done manually for some reason or can be done automatically based on a specified time granularity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea313bff64c8cce", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:01:16.964906Z", + "start_time": "2025-01-20T15:01:16.921763Z" + } + }, + "outputs": [], + "source": [ + "from bertrend.utils.data_loading import group_by_days, load_data\n", + "\n", + "day_granularity = 30\n", + "grouped_data = group_by_days(df=df, day_granularity=day_granularity)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a89b3c810c4575bc", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:01:17.040491Z", + "start_time": "2025-01-20T15:01:16.997388Z" + } + }, + "outputs": [], + "source": [ + "# Number of sliced data\n", + "len(grouped_data)" + ] + }, + { + "cell_type": "markdown", + "id": "9d7ffa03a6ed9330", + "metadata": {}, + "source": [ + "### 4. 
Train topic models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e11789ecb115639", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:11.584568Z", + "start_time": "2025-01-20T15:01:17.180822Z" + } + }, + "outputs": [], + "source": [ + "bertrend.train_topic_models(grouped_data=grouped_data, embedding_model=embedding_model_name, embeddings=embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "855c151c8cd9f93d", + "metadata": {}, + "source": [ + "### 5. (Optional) Save trained_models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a54146c6b5f591b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:12.523789Z", + "start_time": "2025-01-20T15:07:12.377692Z" + }, + "scrolled": true + }, + "outputs": [], + "source": [ + "bertrend.save_models()" + ] + }, + { + "cell_type": "markdown", + "id": "6d76285c9be44e92", + "metadata": {}, + "source": [ + "### 6. Merge models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a95fd062728118e9", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:13.179985Z", + "start_time": "2025-01-20T15:07:12.853779Z" + } + }, + "outputs": [], + "source": [ + "bertrend.merge_all_models()" + ] + }, + { + "cell_type": "markdown", + "id": "d5cbf21f65102cd5", + "metadata": {}, + "source": [ + "### 7. Calculate signal popularity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94859eb8b9944224", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:13.819430Z", + "start_time": "2025-01-20T15:07:13.579473Z" + } + }, + "outputs": [], + "source": [ + "bertrend.calculate_signal_popularity()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a989f7d97083e70", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:13.939621Z", + "start_time": "2025-01-20T15:07:13.854683Z" + } + }, + "outputs": [], + "source": [ + "# List of topic models\n", + "bertrend.topic_models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcba20eeaef6b472", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:14.331855Z", + "start_time": "2025-01-20T15:07:14.116053Z" + } + }, + "outputs": [], + "source": [ + "window_size = 30\n", + "\n", + "# List of strong and weak signals over time\n", + "for ts in bertrend.topic_models.keys():\n", + " print(ts)\n", + " noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = bertrend.classify_signals(window_size, ts)\n", + " if not weak_signal_topics_df.empty:\n", + " print(\"Weak signals\")\n", + " display(weak_signal_topics_df[[\"Topic\",\"Representation\"]].head(5))\n", + " if not strong_signal_topics_df.empty:\n", + " print(\"Strong signals\")\n", + " display(strong_signal_topics_df[[\"Topic\",\"Representation\"]].head(5))\n", + " print()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4582c0cb6c1f6186", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T19:58:55.911033Z", + "start_time": "2025-01-26T19:58:55.907556Z" + } + }, + "outputs": [], + "source": [ + "# selection of one particular timestamp to look at\n", + "selected_timestamp = Timestamp('2017-04-20 00:00:00')\n", + "selected_topic_model = bertrend.topic_models.get(selected_timestamp)\n" + ] + }, + { + "cell_type": "markdown", + "id": "e31285ee5eb9d9f6", + "metadata": {}, + "source": [ + "### Get topic description\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c945b625df18d881", + "metadata": { + 
"ExecuteTime": { + "end_time": "2025-01-20T15:09:22.901513Z", + "start_time": "2025-01-20T15:09:22.731495Z" + } + }, + "outputs": [], + "source": [ + "desc = generate_topic_description(topic_model=selected_topic_model, topic_number=5, filtered_docs=df, language_code=\"en\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e61b903379a0fbd1", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:23:13.250764Z", + "start_time": "2025-01-20T15:23:11.647929Z" + } + }, + "outputs": [], + "source": [ + "desc[\"title\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4dbdd4998e0956a", + "metadata": {}, + "outputs": [], + "source": [ + "desc[\"description\"]" + ] + }, + { + "cell_type": "markdown", + "id": "e27e46b0adc6e88b", + "metadata": {}, + "source": [ + "### Get topic analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cdc44ef6f558aac0", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:17.430211Z", + "start_time": "2025-01-20T15:07:16.745674Z" + } + }, + "outputs": [], + "source": [ + "summary, analysis, formatted_html = analyze_signal(bertrend, 7, selected_timestamp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "531558c5b600cb30", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T15:07:17.433647087Z", + "start_time": "2025-01-19T14:38:52.904786Z" + } + }, + "outputs": [], + "source": [ + "from IPython.display import display, HTML\n", + "display(HTML(formatted_html))" + ] + }, + { + "cell_type": "markdown", + "id": "d4c54df2e25f24c9", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "id": "c922549ec07859a9", + "metadata": {}, + "source": [ + "## Using BERTrend for prospective analysis" + ] + }, + { + "cell_type": "markdown", + "id": "cd3a1210eb53e1e2", + "metadata": {}, + "source": [ + "In the case of a **prospective trend analysis task**, the goal is to **forecast future** developments or outcomes based on current data and trends, enabling organizations to make informed decisions, allocate resources effectively, and strategize for upcoming challenges or opportunities.\n" + ] + }, + { + "cell_type": "markdown", + "id": "100f841b083ce637", + "metadata": {}, + "source": [ + "In this example, we are going to simulate a prospective task:\n", + "- we simulate new data coming in\n", + "- for each new data, we will compute the new topic model, merge it to previous one and detect at each iteration strong and weak signals\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4d88b099fc25b600", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T20:42:31.675644Z", + "start_time": "2025-01-26T20:42:31.671870Z" + } + }, + "outputs": [], + "source": [ + "MY_DATA_DIR = DATA_PATH / \"feeds/feed_sobriete\"\n", + "\n", + "input_data = [\n", + " MY_DATA_DIR / \"2024-12-30_feed_sobriete.jsonl\",\n", + " MY_DATA_DIR / \"2025-01-06_feed_sobriete.jsonl\",\n", + " MY_DATA_DIR / \"2025-01-20_feed_sobriete.jsonl\",\n", + "]\n", + "\n", + "window_size = 7" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a4619e8b7e9fbf91", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T21:04:04.405304Z", + "start_time": "2025-01-26T21:04:04.401150Z" + } + }, + "outputs": [], + "source": [ + "embedding_service_cfg = {\"local\": False, \"host\":\"10.132.5.44\", \"port\": 6464}\n", + "\n", + "embedding_service = EmbeddingService(**embedding_service_cfg)\n", + "embedding_model_name = 
embedding_service.embedding_model_name" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "29f00b403ea81df1", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T20:41:23.289362Z", + "start_time": "2025-01-26T20:41:23.284555Z" + } + }, + "outputs": [], + "source": [ + "BERTREND_MODELS_PATH = MODELS_DIR / \"sobriete_models\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "63e3d13a7d8c0cb", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T21:08:01.471923Z", + "start_time": "2025-01-26T21:08:01.464866Z" + } + }, + "outputs": [], + "source": [ + "def process_new_data(data_slice_path: Path, timestamp: pd.Timestamp):\n", + " logger.debug(f\"Processing new data: {data_slice_path}\")\n", + "\n", + " # Restore previous models\n", + " try:\n", + " bertrend = BERTrend.restore_models(BERTREND_MODELS_PATH)\n", + " except:\n", + " logger.warning(\"Cannot restore previous models, creating new one\")\n", + " bertrend = BERTrend(topic_model=BERTopicModel())\n", + "\n", + " # Read data\n", + " df = load_data(data_slice_path, language=\"French\")\n", + " df = split_data(df)\n", + " text = df[TEXT_COLUMN]\n", + "\n", + " # Embed new data\n", + " embeddings, token_strings, token_embeddings = embedding_service.embed(\n", + " texts=text,\n", + " )\n", + "\n", + " # Create topic model for new data\n", + " bertrend.train_topic_models({timestamp: df}, embeddings=embeddings, embedding_model=embedding_model_name)\n", + " \n", + " # Merge models\n", + " bertrend.merge_all_models()\n", + "\n", + " logger.info(f\"BERTrend contains {len(bertrend.topic_models)} topic models\")\n", + " \n", + " # Save models\n", + " bertrend.save_models(models_path=BERTREND_MODELS_PATH)\n", + "\n", + " \n", + " if not bertrend._are_models_merged:\n", + " return None\n", + " \n", + " # Compute popularities\n", + " bertrend.calculate_signal_popularity()\n", + " \n", + " # classify last signals\n", + " noise_topics_df, weak_signal_topics_df, strong_signal_topics_df = bertrend.classify_signals(window_size, timestamp)\n", + " # TODO: save dfs\n", + "\n", + " if weak_signal_topics_df.empty:\n", + " return None\n", + " \n", + " wt = weak_signal_topics_df['Topic']\n", + " logger.info(f\"Weak topics: {wt}\")\n", + " wt_list = []\n", + " for topic in wt:\n", + " desc = generate_topic_description(topic_model=bertrend.topic_models[timestamp], topic_number=topic, filtered_docs=df, language_code=\"fr\")\n", + " wt_list.append({\"timestamp\": timestamp, \"topic\": topic, \"title\": desc[\"title\"], \"description\": desc[\"description\"]})\n", + "\n", + " return pd.DataFrame(wt_list)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b36e0e226103b8c", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T21:05:36.124752Z", + "start_time": "2025-01-26T21:05:36.122652Z" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2695805f56be632", + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-26T21:12:53.800721Z", + "start_time": "2025-01-26T21:08:10.434372Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-01-27 18:02:32.141\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mprocess_new_data\u001b[0m:\u001b[36m2\u001b[0m - \u001b[34m\u001b[1mProcessing new data: /scratch/nlp/data/bertrend/feeds/feed_sobriete/2024-12-30_feed_sobriete.jsonl\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:32.142\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mrestore_models\u001b[0m:\u001b[36m668\u001b[0m - \u001b[1mLoading models from: /scratch/nlp/cache/bertrend/models/sobriete_models\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:32.335\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.services.embedding_service\u001b[0m:\u001b[36m_remote_embed_documents\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mComputing embeddings...\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.922\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.services.embedding_service\u001b[0m:\u001b[36m_remote_embed_documents\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mComputing embeddings done for batch\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.937\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.services.embedding_service\u001b[0m:\u001b[36m_get_remote_model_name\u001b[0m:\u001b[36m226\u001b[0m - \u001b[34m\u001b[1mModel name: OrdalieTech/Solon-embeddings-large-0.1\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.938\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mtrain_topic_models\u001b[0m:\u001b[36m240\u001b[0m - \u001b[1mTraining topic model 1/1...\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.942\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m148\u001b[0m - \u001b[34m\u001b[1mProcessing period: 2024-12-30 00:00:00\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.942\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m149\u001b[0m - \u001b[34m\u001b[1mNumber of documents: 932\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.942\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m151\u001b[0m - \u001b[34m\u001b[1mCreating topic model...\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.943\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1m\tInitializing BERTopic model\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.943\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m222\u001b[0m - \u001b[32m\u001b[1m\tBERTopic model instance created successfully\u001b[0m\n", + "\u001b[32m2025-01-27 18:02:49.943\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m224\u001b[0m - \u001b[34m\u001b[1m\tFitting BERTopic model\u001b[0m\n", + "2025-01-27 18:02:55,866 - BERTopic - Dimensionality - Fitting the dimensionality reduction algorithm\n", + "2025-01-27 18:03:05,012 - BERTopic - Dimensionality - Completed ✓\n", + "2025-01-27 18:03:05,014 - BERTopic - Cluster - Start clustering the reduced embeddings\n", + "2025-01-27 18:03:05,054 - BERTopic - Cluster - Completed ✓\n", + "2025-01-27 18:03:05,059 - BERTopic - Representation - Extracting topics from clusters using representation models.\n", + "2025-01-27 18:03:10,131 - BERTopic - Representation - Completed ✓\n", + "\u001b[32m2025-01-27 18:03:10.283\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m231\u001b[0m - \u001b[34m\u001b[1m\tReducing outliers\u001b[0m\n", + "2025-01-27 18:03:10,292 - BERTopic - WARNING: Using a custom list of topic 
assignments may lead to errors if topic reduction techniques are used afterwards. Make sure that manually assigning topics is the last step in the pipeline.Note that topic embeddings will also be created through weightedc-TF-IDF embeddings instead of centroid embeddings.\n", + "\u001b[32m2025-01-27 18:03:14.123\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m259\u001b[0m - \u001b[32m\u001b[1m\tBERTopic model fitted successfully\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.124\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m158\u001b[0m - \u001b[34m\u001b[1mTopic model created successfully\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.151\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mtrain_topic_models\u001b[0m:\u001b[36m244\u001b[0m - \u001b[34m\u001b[1mSuccessfully processed period: 2024-12-30 00:00:00\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.151\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mtrain_topic_models\u001b[0m:\u001b[36m269\u001b[0m - \u001b[32m\u001b[1mFinished training all topic models\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.219\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mmerge_all_models\u001b[0m:\u001b[36m294\u001b[0m - \u001b[33m\u001b[1mThis function requires at least two topic models. Ignored\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.219\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mprocess_new_data\u001b[0m:\u001b[36m27\u001b[0m - \u001b[1mBERTrend contains 1 topic models\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.247\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36msave_models\u001b[0m:\u001b[36m661\u001b[0m - \u001b[1mModels saved to: /scratch/nlp/cache/bertrend/models/sobriete_models\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "None" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2025-01-27 18:03:14.252\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36mprocess_new_data\u001b[0m:\u001b[36m2\u001b[0m - \u001b[34m\u001b[1mProcessing new data: /scratch/nlp/data/bertrend/feeds/feed_sobriete/2025-01-06_feed_sobriete.jsonl\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:14.252\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mrestore_models\u001b[0m:\u001b[36m668\u001b[0m - \u001b[1mLoading models from: /scratch/nlp/cache/bertrend/models/sobriete_models\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:17.627\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.services.embedding_service\u001b[0m:\u001b[36m_remote_embed_documents\u001b[0m:\u001b[36m203\u001b[0m - \u001b[34m\u001b[1mComputing embeddings...\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.419\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.services.embedding_service\u001b[0m:\u001b[36m_remote_embed_documents\u001b[0m:\u001b[36m210\u001b[0m - \u001b[34m\u001b[1mComputing embeddings done for batch\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.434\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.services.embedding_service\u001b[0m:\u001b[36m_get_remote_model_name\u001b[0m:\u001b[36m226\u001b[0m - 
\u001b[34m\u001b[1mModel name: OrdalieTech/Solon-embeddings-large-0.1\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.434\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36mtrain_topic_models\u001b[0m:\u001b[36m240\u001b[0m - \u001b[1mTraining topic model 1/1...\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.437\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m148\u001b[0m - \u001b[34m\u001b[1mProcessing period: 2025-01-06 00:00:00\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.437\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m149\u001b[0m - \u001b[34m\u001b[1mNumber of documents: 825\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.438\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTrend\u001b[0m:\u001b[36m_train_by_period\u001b[0m:\u001b[36m151\u001b[0m - \u001b[34m\u001b[1mCreating topic model...\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.438\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1m\tInitializing BERTopic model\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.439\u001b[0m | \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m222\u001b[0m - \u001b[32m\u001b[1m\tBERTopic model instance created successfully\u001b[0m\n", + "\u001b[32m2025-01-27 18:03:32.440\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mbertrend.BERTopicModel\u001b[0m:\u001b[36mfit\u001b[0m:\u001b[36m224\u001b[0m - \u001b[34m\u001b[1m\tFitting BERTopic model\u001b[0m\n", + "2025-01-27 18:03:37,067 - BERTopic - Dimensionality - Fitting the dimensionality reduction algorithm\n", + "2025-01-27 18:03:39,539 - BERTopic - Dimensionality - Completed ✓\n", + "2025-01-27 18:03:39,540 - BERTopic - Cluster - Start clustering the reduced embeddings\n", + "2025-01-27 18:03:39,574 - BERTopic - Cluster - Completed ✓\n", + "2025-01-27 18:03:39,577 - BERTopic - Representation - Extracting topics from clusters using representation models.\n" + ] + } + ], + "source": [ + "for data_file in input_data:\n", + " timestamp = pd.Timestamp(data_file.name.split('_')[0])\n", + " display(process_new_data(data_file, timestamp))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a07f5101-94ff-4e46-a808-d43873cf51fa", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "779b1acd-4270-42c4-a600-75bfeb58d20b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b0ebd4d-c808-4880-8970-0c4c7d162a9d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 4147310..94a61cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,9 +21,11 @@ ipython = "^8.28.0" #accelerate = "^0.34.2" # required? 
bertopic = "0.16.2" black = "^24.10.0" +cron-descriptor = "^1.4.5" datamapplot = "0.3.0" dateparser = "^1.2.0" -dask = "2024.9.1" # issues with >=2025.x (https://github.com/dask/dask/issues/11678) +dask = "2024.12.0" # issues with >=2025.x (https://github.com/dask/dask/issues/11678) +dill = "^0.3.9" gensim = "4.3.2" hdbscan = "^0.8.40" joblib = "^1.4.2" @@ -38,6 +40,7 @@ markdown = "^3.7" nltk = "^3.9.1" numpy = "<2" openai = "^1.58.1" +opentelemetry-exporter-otlp-proto-grpc = "1.25.0" # to avoid error chroma with protobuf pandas = "^2.2.2" plotly = "^5.24.1" plotly-resampler = "^0.10.0"