Skip to content

Commit

Permalink
Updated train topics function for incremental topic learning
Browse files Browse the repository at this point in the history
Updated notebook
  • Loading branch information
picaultj committed Jan 27, 2025
1 parent 8d6b632 commit 7009d59
Show file tree
Hide file tree
Showing 4 changed files with 2,063 additions and 9 deletions.
7 changes: 4 additions & 3 deletions bertrend/BERTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,13 @@ def train_topic_models(

self._is_fitted = True

# Merge the newly obtained topic models with the existing ones
# Update topic_models: Dictionary of trained BERTopic models for each timestamp.
self.topic_models = topic_models
self.topic_models.update(topic_models)
# Update doc_groups: Dictionary of document groups for each timestamp.
self.doc_groups = doc_groups
self.doc_groups.update(doc_groups)
# Update emb_groups: Dictionary of document embeddings for each timestamp.
self.emb_groups = emb_groups
self.emb_groups.update(emb_groups)
logger.success("Finished training all topic models")

def merge_all_models(
Expand Down
11 changes: 6 additions & 5 deletions bertrend/trend_analysis/weak_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,31 +8,32 @@
import scipy
from bertopic import BERTopic
from loguru import logger
from pandas import Timestamp

from bertrend.llm_utils.openai_client import OpenAI_Client
from bertrend import LLM_CONFIG
from bertrend.trend_analysis.prompts import get_prompt, save_html_output


def detect_weak_signals_zeroshot(
topic_models: dict[pd.Timestamp, BERTopic],
topic_models: dict[Timestamp, BERTopic],
zeroshot_topic_list: list[str],
granularity: int,
decay_factor: float = 0.01,
decay_power: float = 2,
) -> dict[str, dict[pd.Timestamp, dict[str, any]]]:
) -> dict[str, dict[Timestamp, dict[str, any]]]:
"""
Detect weak signals based on the zero-shot list of topics to monitor.
Args:
topic_models (Dict[pd.Timestamp, BERTopic]): Dictionary of BERTopic models for each timestamp.
topic_models (Dict[Timestamp, BERTopic]): Dictionary of BERTopic models for each timestamp.
zeroshot_topic_list (List[str]): List of topics to monitor for weak signals.
granularity (int): The granularity of the timestamps in days.
decay_factor (float): The decay factor for exponential decay.
decay_power (float): The decay power for exponential decay.
Returns:
Dict[str, Dict[pd.Timestamp, Dict[str, any]]]: Dictionary of weak signal trends for each monitored topic.
Dict[str, Dict[Timestamp, Dict[str, any]]]: Dictionary of weak signal trends for each monitored topic.
"""
weak_signal_trends = {}

Expand Down Expand Up @@ -329,7 +330,7 @@ def _apply_decay_to_inactive_topics(
topic_last_popularity[topic] = decayed_popularity


def analyze_signal(bertrend, topic_number: int, current_date):
def analyze_signal(bertrend, topic_number: int, current_date: Timestamp):
topic_merge_rows = bertrend.all_merge_histories_df[
bertrend.all_merge_histories_df["Topic1"] == topic_number
].sort_values("Timestamp")
Expand Down
Loading

0 comments on commit 7009d59

Please sign in to comment.