Skip to content

Commit

Permalink
BERTrend class configuration now managed using a single toml config file
Browse files Browse the repository at this point in the history
  • Loading branch information
grosjeang committed Jan 17, 2025
1 parent 54b18c3 commit 98e4194
Show file tree
Hide file tree
Showing 7 changed files with 91 additions and 135 deletions.
67 changes: 34 additions & 33 deletions bertrend/BERTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from bertrend import MODELS_DIR, CACHE_PATH
from bertrend import MODELS_DIR, CACHE_PATH, load_toml_config

from bertrend.topic_model.topic_model import TopicModel
from bertrend.parameters import (
DEFAULT_MIN_SIMILARITY,
DEFAULT_GRANULARITY,
DEFAULT_BERTREND_CONFIG_FILE,
DOC_INFO_DF_FILE,
TOPIC_INFO_DF_FILE,
DOC_GROUPS_FILE,
MODELS_TRAINED_FILE,
EMB_GROUPS_FILE,
GRANULARITY_FILE,
HYPERPARAMS_FILE,
BERTOPIC_SERIALIZATION,
)
Expand All @@ -52,16 +50,15 @@ class BERTrend:

def __init__(
self,
config_file: str | Path = DEFAULT_BERTREND_CONFIG_FILE,
topic_model: TopicModel = None,
zeroshot_topic_list: List[str] = None,
zeroshot_min_similarity: float = 0,
):
self.topic_model_parameters = (
TopicModel() if topic_model is None else topic_model
)
self.zeroshot_topic_list = zeroshot_topic_list
self.zeroshot_min_similarity = zeroshot_min_similarity
self.granularity = DEFAULT_GRANULARITY
# Load configuration file
self.config_file = config_file
self.config = self._load_config()

# Initialize topic model
self.topic_model = TopicModel() if topic_model is None else topic_model

# State variables of BERTrend
self._is_fitted = False
Expand Down Expand Up @@ -90,6 +87,13 @@ def __init__(
# - topic_last_update: Dictionary storing the last update timestamp of each topic.
self.topic_last_update: Dict[int, pd.Timestamp] = {}

def _load_config(self) -> dict:
"""
Load the TOML config file as a dict when instanciating the class.
"""
config = load_toml_config(self.config_file)["bertrend"]
return config

def _train_by_period(
self,
period: pd.Timestamp,
Expand Down Expand Up @@ -124,12 +128,10 @@ def _train_by_period(
logger.debug(f"Number of documents: {len(docs)}")

logger.debug("Creating topic model...")
topic_model = self.topic_model_parameters.fit(
topic_model = self.topic_model.fit(
docs=docs,
embedding_model=embedding_model,
embeddings=embeddings_subset,
zeroshot_topic_list=self.zeroshot_topic_list,
zeroshot_min_similarity=self.zeroshot_min_similarity,
).topic_model

logger.debug("Topic model created successfully")
Expand Down Expand Up @@ -212,10 +214,6 @@ def train_topic_models(
# progress_bar = st.progress(0)
# progress_text = st.empty()

logger.debug(
f"Starting to train topic models with zeroshot_topic_list: {self.zeroshot_topic_list}"
)

for i, (period, group) in enumerate(non_empty_groups):
try:
logger.info(f"Training topic model {i+1}/{len(non_empty_groups)}...")
Expand Down Expand Up @@ -250,9 +248,14 @@ def train_topic_models(

def merge_all_models(
self,
min_similarity: int = DEFAULT_MIN_SIMILARITY,
min_similarity: int | None = None,
):
"""Merge together all topic models."""
# Get default BERTrend config if argument is not provided
if min_similarity is None:
min_similarity = self.config["min_similarity"]

# Check if model is fitted
if not self._is_fitted:
raise RuntimeError("You must fit the BERTrend model before merging models.")

Expand Down Expand Up @@ -329,9 +332,8 @@ def merge_all_models(

def calculate_signal_popularity(
self,
granularity: int = DEFAULT_GRANULARITY,
decay_factor: float = 0.01,
decay_power: float = 2,
decay_factor: float | None = None,
decay_power: float | None = None,
):
"""
Compute the popularity of signals (topics) over time, accounting for merges and applying decay.
Expand All @@ -349,8 +351,13 @@ def calculate_signal_popularity(
Returns:
"""
self.granularity = granularity
# Get default BERTrend config if argument is not provided
if decay_factor is None:
decay_factor = self.config["decay_factor"]
if decay_power is None:
decay_power = self.config["decay_power"]

# Check if models are merged
if not self._are_models_merged:
raise RuntimeWarning(
"You must merge topic models first before computing signal popularity."
Expand All @@ -362,7 +369,7 @@ def calculate_signal_popularity(

min_timestamp = self.all_merge_histories_df["Timestamp"].min()
max_timestamp = self.all_merge_histories_df["Timestamp"].max()
granularity_timedelta = pd.Timedelta(days=granularity)
granularity_timedelta = pd.Timedelta(days=self.config["granularity"])
time_range = pd.date_range(
start=min_timestamp.to_pydatetime(),
end=(max_timestamp + granularity_timedelta).to_pydatetime(),
Expand Down Expand Up @@ -450,10 +457,7 @@ def save_models(self, models_path: Path = MODELS_DIR):

# Save topic model parameters
with open(CACHE_PATH / HYPERPARAMS_FILE, "wb") as f:
pickle.dump(self.topic_model_parameters, f)
# Save granularity file
with open(CACHE_PATH / GRANULARITY_FILE, "wb") as f:
pickle.dump(self.granularity, f)
pickle.dump(self.topic_model, f)
# Save doc_groups file
with open(CACHE_PATH / DOC_GROUPS_FILE, "wb") as f:
pickle.dump(self.doc_groups, f)
Expand All @@ -478,10 +482,7 @@ def restore_models(cls, models_path: Path = MODELS_DIR) -> "BERTrend":

# load topic model parameters
with open(CACHE_PATH / HYPERPARAMS_FILE, "rb") as f:
bertrend.topic_model_parameters = pickle.load(f)
# load granularity file
with open(CACHE_PATH / GRANULARITY_FILE, "rb") as f:
bertrend.granularity = pickle.load(f)
bertrend.topic_model = pickle.load(f)
# load doc_groups file
with open(CACHE_PATH / DOC_GROUPS_FILE, "rb") as f:
bertrend.doc_groups = pickle.load(f)
Expand Down
9 changes: 6 additions & 3 deletions bertrend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@

from bertrend.utils.config_utils import load_toml_config

BERTREND_DEFAULT_CONFIG_PATH = Path(__file__).parent / "bertrend.toml"
# default config files path
DEFAULT_BERTOPIC_CONFIG_FILE = (
Path(__file__).parent / "topic_model" / "topic_model_default_config.toml"
)
BERTREND_DEFAULT_CONFIG_PATH = Path(__file__).parent / "bertrend_default_config.toml"

# Read config
BERTREND_CONFIG = load_toml_config(BERTREND_DEFAULT_CONFIG_PATH)
BERTOPIC_PARAMETERS = BERTREND_CONFIG["bertopic_parameters"]
BERTREND_PARAMETERS = BERTREND_CONFIG["bertrend_parameters"]
BERTREND_PARAMETERS = BERTREND_CONFIG["bertrend"]
EMBEDDING_CONFIG = BERTREND_CONFIG["embedding_service"]
LLM_CONFIG = BERTREND_CONFIG["llm_service"]

Expand Down
28 changes: 7 additions & 21 deletions bertrend/bertrend.toml → bertrend/bertrend_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,15 @@
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.

# BERTopic Hyperparameters
[bertopic_parameters]
umap_n_components = 5
umap_n_neighbors = 5
umap_min_dist = 0.0
hdbscan_min_cluster_size = 5
hdbscan_min_samples = 5
top_n_words = 10
min_df = 1
# BERTrend Hyperparameters
[bertrend]
# Data split settings
granularity = 2
# Merge model settings
min_similarity = 0.7
zeroshot_min_similarity = 0.5
bertopic_serialization = "safetensors" # or pickle
mmr_diversity = 0.3
outlier_reduction_strategy = "c-tf-idf" # or "embeddings"
# other constants
zeroshot_topics = [] # default list of topics
language="French" # or "English"
representation_models=["MaximalMarginalRelevance"] # and / or "KeyBERTInspired", "OpenAI"

# BERTrend Hyperparameters
[bertrend_parameters]
# signal classification settings
# Signal popularity settings
decay_factor = 0.01
decay_power = 2
signal_classif_lower_bound = 10
signal_classif_upper_bound = 75

Expand Down
49 changes: 23 additions & 26 deletions bertrend/demos/demos_utils/parameters_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from code_editor import code_editor

from bertrend import EMBEDDING_CONFIG
from bertrend import EMBEDDING_CONFIG, load_toml_config
from bertrend.demos.demos_utils.state_utils import (
register_widget,
save_widget_state,
Expand All @@ -15,15 +15,14 @@
reset_widget_state,
)
from bertrend.parameters import (
DEFAULT_MIN_SIMILARITY,
DEFAULT_ZEROSHOT_MIN_SIMILARITY,
EMBEDDING_DTYPES,
LANGUAGES,
ENGLISH_EMBEDDING_MODELS,
FRENCH_EMBEDDING_MODELS,
REPRESENTATION_MODELS,
MMR_REPRESENTATION_MODEL,
DEFAULT_BERTOPIC_CONFIG_FILE,
DEFAULT_BERTREND_CONFIG_FILE,
)

from bertrend.demos.demos_utils.icons import INFO_ICON
Expand Down Expand Up @@ -110,7 +109,7 @@ def display_bertopic_hyperparameters():

# Add code editor to edit the config file
st.write(INFO_ICON + " CTRL + Enter to update")
config_editor = code_editor(toml_txt, lang="yaml")
config_editor = code_editor(toml_txt, lang="toml")

# If code is edited, update config
if config_editor["text"] != "":
Expand All @@ -122,29 +121,27 @@ def display_bertopic_hyperparameters():

def display_bertrend_hyperparameters():
"""UI settings for Bertrend hyperparameters"""
with st.expander("Merging Hyperparameters", expanded=False):
register_widget("min_similarity")
st.slider(
"Minimum Similarity for Merging",
0.0,
1.0,
DEFAULT_MIN_SIMILARITY,
0.01,
key="min_similarity",
on_change=save_widget_state,
)
with st.expander("BERTrend Model Settings", expanded=False):
# Get BERTrend default configuration
with open(DEFAULT_BERTREND_CONFIG_FILE, "r") as f:
# Load default parameter the first time
toml_txt = f.read()

with st.expander("Zero-shot Parameters", expanded=False):
register_widget("zeroshot_min_similarity")
st.slider(
"Zeroshot Minimum Similarity",
0.0,
1.0,
DEFAULT_ZEROSHOT_MIN_SIMILARITY,
0.01,
key="zeroshot_min_similarity",
on_change=save_widget_state,
)
# Add code editor to edit the config file
st.write(INFO_ICON + " CTRL + Enter to update")
config_editor = code_editor(toml_txt, lang="toml")

# If code is edited, update config
if config_editor["text"] != "":
st.session_state["bertrend_config"] = config_editor["text"]
# Else use default config
else:
st.session_state["bertrend_config"] = toml_txt

# Save granularity in session state as it is re-used in other components
st.session_state["granularity"] = load_toml_config(
st.session_state["bertrend_config"]
)["bertrend"]["granularity"]


def display_representation_model_options():
Expand Down
49 changes: 8 additions & 41 deletions bertrend/demos/weak_signals/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@
import plotly.graph_objects as go
from loguru import logger

from bertrend import (
ZEROSHOT_TOPICS_DATA_DIR,
CACHE_PATH,
)
from bertrend import ZEROSHOT_TOPICS_DATA_DIR, CACHE_PATH
from bertrend.BERTrend import BERTrend
from bertrend.demos.demos_utils import is_admin_mode
from bertrend.demos.demos_utils.data_loading_component import (
Expand Down Expand Up @@ -181,21 +178,12 @@ def training_page():
st.warning(NO_EMBEDDINGS_WARNING_MESSAGE, icon=WARNING_ICON)
st.stop()

# Select granularity
st.number_input(
"Select Granularity",
value=DEFAULT_GRANULARITY,
min_value=1,
max_value=30,
key="granularity_select",
help="Number of days to split the data by",
)

# Show documents per grouped timestamp
with st.expander("Documents per Timestamp", expanded=True):
st.write(f"Granularity: {st.session_state['granularity']}")
grouped_data = group_by_days(
SessionStateManager.get_dataframe("time_filtered_df"),
day_granularity=SessionStateManager.get("granularity_select"),
day_granularity=st.session_state["granularity"],
)
non_empty_timestamps = [
timestamp for timestamp, group in grouped_data.items() if not group.empty
Expand Down Expand Up @@ -237,35 +225,16 @@ def training_page():
# FIXME: called twice (see above)
grouped_data = group_by_days(
SessionStateManager.get_dataframe("time_filtered_df"),
day_granularity=SessionStateManager.get("granularity_select"),
day_granularity=st.session_state["granularity"],
)

# Initialize topic model
topic_model = TopicModel(
umap_n_components=SessionStateManager.get("umap_n_components"),
umap_n_neighbors=SessionStateManager.get("umap_n_neighbors"),
hdbscan_min_cluster_size=SessionStateManager.get(
"hdbscan_min_cluster_size"
),
hdbscan_min_samples=SessionStateManager.get("hdbscan_min_samples"),
hdbscan_cluster_selection_method=SessionStateManager.get(
"hdbscan_cluster_selection_method"
),
vectorizer_ngram_range=SessionStateManager.get(
"vectorizer_ngram_range"
),
min_df=SessionStateManager.get("min_df"),
top_n_words=SessionStateManager.get("top_n_words"),
language=SessionStateManager.get("language"),
)
topic_model = TopicModel(st.session_state["bertopic_config"])

# Created BERTrend object
bertrend = BERTrend(
config_file=st.session_state["bertrend_config"],
topic_model=topic_model,
zeroshot_topic_list=zeroshot_topic_list,
zeroshot_min_similarity=SessionStateManager.get(
"zeroshot_min_similarity"
),
)
# Train topic models on data
bertrend.train_topic_models(
Expand Down Expand Up @@ -295,9 +264,7 @@ def training_page():
min_similarity=SessionStateManager.get("min_similarity"),
)

bertrend.calculate_signal_popularity(
granularity=SessionStateManager.get("granularity_select"),
)
bertrend.calculate_signal_popularity()

SessionStateManager.set("popularity_computed", True)

Expand Down Expand Up @@ -346,7 +313,7 @@ def analysis_page():
weak_signal_trends = detect_weak_signals_zeroshot(
topic_models,
zeroshot_topic_list,
SessionStateManager.get("granularity_select"),
st.session_state["granularity"],
)
with st.expander("Zero-shot Weak Signal Trends", expanded=False):
fig_trend = go.Figure()
Expand Down
Loading

0 comments on commit 98e4194

Please sign in to comment.