From 98e41940f7c0bfb662d6a01fdb2b0d416d5b296a Mon Sep 17 00:00:00 2001
From: Guillaume Grosjean
Date: Fri, 17 Jan 2025 10:45:19 +0100
Subject: [PATCH] BERTrend class configuration now managed using a single TOML config file

---
 bertrend/BERTrend.py                          | 67 ++++++++++---------
 bertrend/__init__.py                          |  9 ++-
 ...rend.toml => bertrend_default_config.toml} | 28 ++------
 .../demos/demos_utils/parameters_component.py | 49 +++++++-------
 bertrend/demos/weak_signals/app.py            | 49 +++-----------
 .../weak_signals/visualizations_utils.py      | 15 +++--
 bertrend/parameters.py                        |  9 ++-
 7 files changed, 91 insertions(+), 135 deletions(-)
 rename bertrend/{bertrend.toml => bertrend_default_config.toml} (71%)

diff --git a/bertrend/BERTrend.py b/bertrend/BERTrend.py
index f7e13fe..42454e8 100644
--- a/bertrend/BERTrend.py
+++ b/bertrend/BERTrend.py
@@ -15,18 +15,16 @@
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 
-from bertrend import MODELS_DIR, CACHE_PATH
+from bertrend import MODELS_DIR, CACHE_PATH, load_toml_config
 from bertrend.topic_model.topic_model import TopicModel
 from bertrend.parameters import (
-    DEFAULT_MIN_SIMILARITY,
-    DEFAULT_GRANULARITY,
+    DEFAULT_BERTREND_CONFIG_FILE,
     DOC_INFO_DF_FILE,
     TOPIC_INFO_DF_FILE,
     DOC_GROUPS_FILE,
     MODELS_TRAINED_FILE,
     EMB_GROUPS_FILE,
-    GRANULARITY_FILE,
     HYPERPARAMS_FILE,
     BERTOPIC_SERIALIZATION,
 )
@@ -52,16 +50,15 @@ class BERTrend:
 
     def __init__(
         self,
+        config_file: str | Path = DEFAULT_BERTREND_CONFIG_FILE,
         topic_model: TopicModel = None,
-        zeroshot_topic_list: List[str] = None,
-        zeroshot_min_similarity: float = 0,
     ):
-        self.topic_model_parameters = (
-            TopicModel() if topic_model is None else topic_model
-        )
-        self.zeroshot_topic_list = zeroshot_topic_list
-        self.zeroshot_min_similarity = zeroshot_min_similarity
-        self.granularity = DEFAULT_GRANULARITY
+        # Load configuration file
+        self.config_file = config_file
+        self.config = self._load_config()
+
+        # Initialize topic model
+        self.topic_model = TopicModel() if topic_model is None else topic_model
 
         # State variables of BERTrend
         self._is_fitted = False
@@ -90,6 +87,13 @@ def __init__(
         # - topic_last_update: Dictionary storing the last update timestamp of each topic.
         self.topic_last_update: Dict[int, pd.Timestamp] = {}
 
+    def _load_config(self) -> dict:
+        """
+        Load the TOML config file as a dict when instantiating the class.
+        """
+        config = load_toml_config(self.config_file)["bertrend"]
+        return config
+
     def _train_by_period(
         self,
         period: pd.Timestamp,
@@ -124,12 +128,10 @@ def _train_by_period(
         logger.debug(f"Number of documents: {len(docs)}")
 
         logger.debug("Creating topic model...")
-        topic_model = self.topic_model_parameters.fit(
+        topic_model = self.topic_model.fit(
             docs=docs,
             embedding_model=embedding_model,
             embeddings=embeddings_subset,
-            zeroshot_topic_list=self.zeroshot_topic_list,
-            zeroshot_min_similarity=self.zeroshot_min_similarity,
         ).topic_model
         logger.debug("Topic model created successfully")
 
@@ -212,10 +214,6 @@ def train_topic_models(
         # progress_bar = st.progress(0)
         # progress_text = st.empty()
 
-        logger.debug(
-            f"Starting to train topic models with zeroshot_topic_list: {self.zeroshot_topic_list}"
-        )
-
         for i, (period, group) in enumerate(non_empty_groups):
             try:
                 logger.info(f"Training topic model {i+1}/{len(non_empty_groups)}...")
@@ -250,9 +248,14 @@ def train_topic_models(
 
     def merge_all_models(
         self,
-        min_similarity: int = DEFAULT_MIN_SIMILARITY,
+        min_similarity: float | None = None,
     ):
         """Merge together all topic models."""
+        # Get default BERTrend config if argument is not provided
+        if min_similarity is None:
+            min_similarity = self.config["min_similarity"]
+
+        # Check if model is fitted
         if not self._is_fitted:
             raise RuntimeError("You must fit the BERTrend model before merging models.")
 
@@ -329,9 +332,8 @@ def merge_all_models(
 
     def calculate_signal_popularity(
         self,
-        granularity: int = DEFAULT_GRANULARITY,
-        decay_factor: float = 0.01,
-        decay_power: float = 2,
+        decay_factor: float | None = None,
+        decay_power: float | None = None,
     ):
         """
         Compute the popularity of signals (topics) over time, accounting for merges and applying decay.
@@ -349,8 +351,13 @@ def calculate_signal_popularity(
 
         Returns:
         """
-        self.granularity = granularity
+        # Get default BERTrend config if argument is not provided
+        if decay_factor is None:
+            decay_factor = self.config["decay_factor"]
+        if decay_power is None:
+            decay_power = self.config["decay_power"]
 
+        # Check if models are merged
         if not self._are_models_merged:
             raise RuntimeWarning(
                 "You must merge topic models first before computing signal popularity."
@@ -362,7 +369,7 @@ def calculate_signal_popularity( min_timestamp = self.all_merge_histories_df["Timestamp"].min() max_timestamp = self.all_merge_histories_df["Timestamp"].max() - granularity_timedelta = pd.Timedelta(days=granularity) + granularity_timedelta = pd.Timedelta(days=self.config["granularity"]) time_range = pd.date_range( start=min_timestamp.to_pydatetime(), end=(max_timestamp + granularity_timedelta).to_pydatetime(), @@ -450,10 +457,7 @@ def save_models(self, models_path: Path = MODELS_DIR): # Save topic model parameters with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f: - pickle.dump(self.topic_model_parameters, f) - # Save granularity file - with open(CACHE_PATH / GRANULARITY_FILE, "wb") as f: - pickle.dump(self.granularity, f) + pickle.dump(self.topic_model, f) # Save doc_groups file with open(CACHE_PATH / DOC_GROUPS_FILE, "wb") as f: pickle.dump(self.doc_groups, f) @@ -478,10 +482,7 @@ def restore_models(cls, models_path: Path = MODELS_DIR) -> "BERTrend": # load topic model parameters with open(CACHE_PATH / HYPERPARAMS_FILE, "rb") as f: - bertrend.topic_model_parameters = pickle.load(f) - # load granularity file - with open(CACHE_PATH / GRANULARITY_FILE, "rb") as f: - bertrend.granularity = pickle.load(f) + bertrend.topic_model = pickle.load(f) # load doc_groups file with open(CACHE_PATH / DOC_GROUPS_FILE, "rb") as f: bertrend.doc_groups = pickle.load(f) diff --git a/bertrend/__init__.py b/bertrend/__init__.py index d6ee0de..18308d9 100644 --- a/bertrend/__init__.py +++ b/bertrend/__init__.py @@ -7,12 +7,15 @@ from bertrend.utils.config_utils import load_toml_config -BERTREND_DEFAULT_CONFIG_PATH = Path(__file__).parent / "bertrend.toml" +# default config files path +DEFAULT_BERTOPIC_CONFIG_FILE = ( + Path(__file__).parent / "topic_model" / "topic_model_default_config.toml" +) +BERTREND_DEFAULT_CONFIG_PATH = Path(__file__).parent / "bertrend_default_config.toml" # Read config BERTREND_CONFIG = load_toml_config(BERTREND_DEFAULT_CONFIG_PATH) -BERTOPIC_PARAMETERS = BERTREND_CONFIG["bertopic_parameters"] -BERTREND_PARAMETERS = BERTREND_CONFIG["bertrend_parameters"] +BERTREND_PARAMETERS = BERTREND_CONFIG["bertrend"] EMBEDDING_CONFIG = BERTREND_CONFIG["embedding_service"] LLM_CONFIG = BERTREND_CONFIG["llm_service"] diff --git a/bertrend/bertrend.toml b/bertrend/bertrend_default_config.toml similarity index 71% rename from bertrend/bertrend.toml rename to bertrend/bertrend_default_config.toml index 35a8a61..a781262 100644 --- a/bertrend/bertrend.toml +++ b/bertrend/bertrend_default_config.toml @@ -3,29 +3,15 @@ # SPDX-License-Identifier: MPL-2.0 # This file is part of BERTrend. 
-# BERTopic Hyperparameters
-[bertopic_parameters]
-umap_n_components = 5
-umap_n_neighbors = 5
-umap_min_dist = 0.0
-hdbscan_min_cluster_size = 5
-hdbscan_min_samples = 5
-top_n_words = 10
-min_df = 1
+# BERTrend Hyperparameters
+[bertrend]
+# Data split settings
 granularity = 2
+# Merge model settings
 min_similarity = 0.7
-zeroshot_min_similarity = 0.5
-bertopic_serialization = "safetensors" # or pickle
-mmr_diversity = 0.3
-outlier_reduction_strategy = "c-tf-idf" # or "embeddings"
-# other constants
-zeroshot_topics = [] # default list of topics
-language="French" # or "English"
-representation_models=["MaximalMarginalRelevance"] # and / or "KeyBERTInspired", "OpenAI"
-
-# BERTrend Hyperparameters
-[bertrend_parameters]
-# signal classification settings
+# Signal popularity settings
+decay_factor = 0.01
+decay_power = 2
 signal_classif_lower_bound = 10
 signal_classif_upper_bound = 75
diff --git a/bertrend/demos/demos_utils/parameters_component.py b/bertrend/demos/demos_utils/parameters_component.py
index e6a6904..68a150c 100644
--- a/bertrend/demos/demos_utils/parameters_component.py
+++ b/bertrend/demos/demos_utils/parameters_component.py
@@ -6,7 +6,7 @@
 
 from code_editor import code_editor
 
-from bertrend import EMBEDDING_CONFIG
+from bertrend import EMBEDDING_CONFIG, load_toml_config
 from bertrend.demos.demos_utils.state_utils import (
     register_widget,
     save_widget_state,
@@ -15,8 +15,6 @@
     reset_widget_state,
 )
 from bertrend.parameters import (
-    DEFAULT_MIN_SIMILARITY,
-    DEFAULT_ZEROSHOT_MIN_SIMILARITY,
     EMBEDDING_DTYPES,
     LANGUAGES,
     ENGLISH_EMBEDDING_MODELS,
@@ -24,6 +22,7 @@
     REPRESENTATION_MODELS,
     MMR_REPRESENTATION_MODEL,
     DEFAULT_BERTOPIC_CONFIG_FILE,
+    DEFAULT_BERTREND_CONFIG_FILE,
 )
 from bertrend.demos.demos_utils.icons import INFO_ICON
 
@@ -110,7 +109,7 @@ def display_bertopic_hyperparameters():
 
         # Add code editor to edit the config file
         st.write(INFO_ICON + " CTRL + Enter to update")
-        config_editor = code_editor(toml_txt, lang="yaml")
+        config_editor = code_editor(toml_txt, lang="toml")
 
         # If code is edited, update config
         if config_editor["text"] != "":
@@ -122,29 +121,27 @@ def display_bertopic_hyperparameters():
 
 def display_bertrend_hyperparameters():
     """UI settings for Bertrend hyperparameters"""
-    with st.expander("Merging Hyperparameters", expanded=False):
-        register_widget("min_similarity")
-        st.slider(
-            "Minimum Similarity for Merging",
-            0.0,
-            1.0,
-            DEFAULT_MIN_SIMILARITY,
-            0.01,
-            key="min_similarity",
-            on_change=save_widget_state,
-        )
+    with st.expander("BERTrend Model Settings", expanded=False):
+        # Get BERTrend default configuration
+        with open(DEFAULT_BERTREND_CONFIG_FILE, "r") as f:
+            # Load default parameters the first time
+            toml_txt = f.read()
 
-    with st.expander("Zero-shot Parameters", expanded=False):
-        register_widget("zeroshot_min_similarity")
-        st.slider(
-            "Zeroshot Minimum Similarity",
-            0.0,
-            1.0,
-            DEFAULT_ZEROSHOT_MIN_SIMILARITY,
-            0.01,
-            key="zeroshot_min_similarity",
-            on_change=save_widget_state,
-        )
+        # Add code editor to edit the config file
+        st.write(INFO_ICON + " CTRL + Enter to update")
+        config_editor = code_editor(toml_txt, lang="toml")
+
+        # If code is edited, update config
+        if config_editor["text"] != "":
+            st.session_state["bertrend_config"] = config_editor["text"]
+        # Else use default config
+        else:
+            st.session_state["bertrend_config"] = toml_txt
+
+        # Save granularity in session state as it is re-used in other components
+        st.session_state["granularity"] = load_toml_config(
+            st.session_state["bertrend_config"]
+        )["bertrend"]["granularity"]
 
 
 def display_representation_model_options():
diff --git a/bertrend/demos/weak_signals/app.py b/bertrend/demos/weak_signals/app.py
index 352f3f5..7bb79c1 100644
--- a/bertrend/demos/weak_signals/app.py
+++ b/bertrend/demos/weak_signals/app.py
@@ -13,10 +13,7 @@
 import plotly.graph_objects as go
 from loguru import logger
 
-from bertrend import (
-    ZEROSHOT_TOPICS_DATA_DIR,
-    CACHE_PATH,
-)
+from bertrend import ZEROSHOT_TOPICS_DATA_DIR, CACHE_PATH
 from bertrend.BERTrend import BERTrend
 from bertrend.demos.demos_utils import is_admin_mode
 from bertrend.demos.demos_utils.data_loading_component import (
@@ -181,21 +178,12 @@ def training_page():
         st.warning(NO_EMBEDDINGS_WARNING_MESSAGE, icon=WARNING_ICON)
         st.stop()
 
-    # Select granularity
-    st.number_input(
-        "Select Granularity",
-        value=DEFAULT_GRANULARITY,
-        min_value=1,
-        max_value=30,
-        key="granularity_select",
-        help="Number of days to split the data by",
-    )
-
     # Show documents per grouped timestamp
     with st.expander("Documents per Timestamp", expanded=True):
+        st.write(f"Granularity: {st.session_state['granularity']}")
         grouped_data = group_by_days(
             SessionStateManager.get_dataframe("time_filtered_df"),
-            day_granularity=SessionStateManager.get("granularity_select"),
+            day_granularity=st.session_state["granularity"],
         )
         non_empty_timestamps = [
             timestamp for timestamp, group in grouped_data.items() if not group.empty
@@ -237,35 +225,16 @@ def training_page():
             # FIXME: called twice (see above)
             grouped_data = group_by_days(
                 SessionStateManager.get_dataframe("time_filtered_df"),
-                day_granularity=SessionStateManager.get("granularity_select"),
+                day_granularity=st.session_state["granularity"],
             )
 
             # Initialize topic model
-            topic_model = TopicModel(
-                umap_n_components=SessionStateManager.get("umap_n_components"),
-                umap_n_neighbors=SessionStateManager.get("umap_n_neighbors"),
-                hdbscan_min_cluster_size=SessionStateManager.get(
-                    "hdbscan_min_cluster_size"
-                ),
-                hdbscan_min_samples=SessionStateManager.get("hdbscan_min_samples"),
-                hdbscan_cluster_selection_method=SessionStateManager.get(
-                    "hdbscan_cluster_selection_method"
-                ),
-                vectorizer_ngram_range=SessionStateManager.get(
-                    "vectorizer_ngram_range"
-                ),
-                min_df=SessionStateManager.get("min_df"),
-                top_n_words=SessionStateManager.get("top_n_words"),
-                language=SessionStateManager.get("language"),
-            )
+            topic_model = TopicModel(st.session_state["bertopic_config"])
 
             # Created BERTrend object
             bertrend = BERTrend(
+                config_file=st.session_state["bertrend_config"],
                 topic_model=topic_model,
-                zeroshot_topic_list=zeroshot_topic_list,
-                zeroshot_min_similarity=SessionStateManager.get(
-                    "zeroshot_min_similarity"
-                ),
             )
             # Train topic models on data
             bertrend.train_topic_models(
@@ -295,9 +264,7 @@ def training_page():
                 min_similarity=SessionStateManager.get("min_similarity"),
             )
 
-            bertrend.calculate_signal_popularity(
-                granularity=SessionStateManager.get("granularity_select"),
-            )
+            bertrend.calculate_signal_popularity()
 
             SessionStateManager.set("popularity_computed", True)
 
@@ -346,7 +313,7 @@ def analysis_page():
             weak_signal_trends = detect_weak_signals_zeroshot(
                 topic_models,
                 zeroshot_topic_list,
-                SessionStateManager.get("granularity_select"),
+                st.session_state["granularity"],
             )
             with st.expander("Zero-shot Weak Signal Trends", expanded=False):
                 fig_trend = go.Figure()
diff --git a/bertrend/demos/weak_signals/visualizations_utils.py b/bertrend/demos/weak_signals/visualizations_utils.py
index 8c5107e..eb77797 100644
--- a/bertrend/demos/weak_signals/visualizations_utils.py
+++ b/bertrend/demos/weak_signals/visualizations_utils.py
@@ -149,12 +149,15 @@ def display_popularity_evolution():
     min_datetime = all_merge_histories_df["Timestamp"].min().to_pydatetime()
     max_datetime = all_merge_histories_df["Timestamp"].max().to_pydatetime()
 
+    # Get granularity
+    granularity = st.session_state["granularity"]
+
     # Slider to select the date
     current_date = st.slider(
         "Current date",
         min_value=min_datetime,
         max_value=max_datetime,
-        step=pd.Timedelta(days=SessionStateManager.get("granularity_select")),
+        step=pd.Timedelta(days=granularity),
         format="YYYY-MM-DD",
         help="""The earliest selectable date corresponds to the earliest timestamp when topics were merged
         (with the smallest possible value being the earliest timestamp in the provided data).
@@ -163,8 +166,6 @@ def display_popularity_evolution():
         key="current_date",
     )
 
-    granularity = SessionStateManager.get("granularity_select")
-
     # Compute threshold values
     window_start, window_end, all_popularity_values, q1, q3 = (
         compute_popularity_values_and_thresholds(
@@ -205,6 +206,7 @@ def display_popularity_evolution():
 def save_signal_evolution():
     """Save Signal Evolution Data to investigate later on in a separate notebook"""
     bertrend = SessionStateManager.get("bertrend")
+    granularity = SessionStateManager.get("granularity")
     all_merge_histories_df = bertrend.all_merge_histories_df
     min_datetime = all_merge_histories_df["Timestamp"].min().to_pydatetime()
     max_datetime = all_merge_histories_df["Timestamp"].max().to_pydatetime()
@@ -215,7 +217,7 @@ def save_signal_evolution():
         options=pd.date_range(
             start=min_datetime,
             end=max_datetime,
-            freq=pd.Timedelta(days=SessionStateManager.get("granularity_select")),
+            freq=pd.Timedelta(days=granularity),
         ),
         value=(min_datetime, max_datetime),
         format_func=lambda x: x.strftime("%Y-%m-%d"),
@@ -229,7 +231,7 @@ def save_signal_evolution():
         topic_last_popularity=bertrend.topic_last_popularity,
         topic_last_update=bertrend.topic_last_update,
         window_size=SessionStateManager.get("window_size"),
-        granularity=SessionStateManager.get("granularity_select"),
+        granularity=granularity,
         start_timestamp=pd.Timestamp(start_date),
         end_timestamp=pd.Timestamp(end_date),
     )
@@ -302,6 +304,7 @@ def display_signal_analysis(
     """Display a LLM-based analyis of a specific topic."""
     language = SessionStateManager.get("language")
     bertrend = SessionStateManager.get("bertrend")
+    granularity = SessionStateManager.get("granularity")
    all_merge_histories_df = bertrend.all_merge_histories_df
 
     st.subheader("Signal Interpretation")
@@ -310,7 +313,7 @@ def display_signal_analysis(
         topic_number,
         SessionStateManager.get("current_date"),
         all_merge_histories_df,
-        SessionStateManager.get("granularity_select"),
+        granularity,
         language,
     )
 
diff --git a/bertrend/parameters.py b/bertrend/parameters.py
index 329eafe..e105baa 100644
--- a/bertrend/parameters.py
+++ b/bertrend/parameters.py
@@ -8,7 +8,7 @@
 
 import torch
 
-from bertrend import BERTOPIC_PARAMETERS, BERTREND_PARAMETERS
+from bertrend import BERTREND_PARAMETERS
 
 stopwords_en_file = Path(__file__).parent / "resources" / "stopwords-en.json"
 stopwords_fr_file = Path(__file__).parent / "resources" / "stopwords-fr.json"
@@ -63,10 +63,7 @@
 DEFAULT_BERTOPIC_CONFIG_FILE = (
     Path(__file__).parent / "topic_model" / "topic_model_default_config.toml"
 )
-DEFAULT_GRANULARITY = BERTOPIC_PARAMETERS["granularity"]
-DEFAULT_MIN_SIMILARITY = BERTOPIC_PARAMETERS["min_similarity"]
-DEFAULT_ZEROSHOT_MIN_SIMILARITY = BERTOPIC_PARAMETERS["zeroshot_min_similarity"]
-BERTOPIC_SERIALIZATION = BERTOPIC_PARAMETERS["bertopic_serialization"]
+BERTOPIC_SERIALIZATION = "safetensors" # or pickle
 LANGUAGES = ["French", "English"]
 REPRESENTATION_MODELS = [
     MMR_REPRESENTATION_MODEL,
@@ -75,6 +72,8 @@
 ]
 
 # BERTrend parameters
+DEFAULT_BERTREND_CONFIG_FILE = Path(__file__).parent / "bertrend_default_config.toml"
+
 # Signal classification Settings
 SIGNAL_CLASSIF_LOWER_BOUND = BERTREND_PARAMETERS["signal_classif_lower_bound"]
 SIGNAL_CLASSIF_UPPER_BOUND = BERTREND_PARAMETERS["signal_classif_upper_bound"]
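
Example of the resulting configuration flow. This is a minimal sketch based on
the patch above, not part of the diff itself; "my_bertrend_config.toml" is a
placeholder path, and its values simply mirror the new
bertrend_default_config.toml:

    # my_bertrend_config.toml (hypothetical override of the packaged defaults)
    #
    #   [bertrend]
    #   granularity = 2
    #   min_similarity = 0.7
    #   decay_factor = 0.01
    #   decay_power = 2
    #   signal_classif_lower_bound = 10
    #   signal_classif_upper_bound = 75

    from bertrend.BERTrend import BERTrend

    # All BERTrend hyperparameters are now read from the [bertrend] table of a
    # single TOML file; omitting config_file falls back to
    # DEFAULT_BERTREND_CONFIG_FILE (the packaged bertrend_default_config.toml).
    bertrend = BERTrend(config_file="my_bertrend_config.toml")

    # merge_all_models() and calculate_signal_popularity() take their defaults
    # from the loaded config when called without arguments.
    print(bertrend.config["granularity"], bertrend.config["min_similarity"])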