[Refactor] Simplification of demo app states
picaultj committed Dec 18, 2024
1 parent 08752e3 commit ba8869b
Showing 4 changed files with 113 additions and 126 deletions.
131 changes: 67 additions & 64 deletions bertrend/BERTrend.py
@@ -49,19 +49,37 @@ def __init__(
         self.zeroshot_topic_list = zeroshot_topic_list
         self.zeroshot_min_similarity = zeroshot_min_similarity

-        # Variables related to time-based topic models
-        self.emb_groups = None
-        self.doc_groups = None
-        self.topic_models = None
-
-        # State variables
-        self.merged_topics = None
+        # State variables of BERTrend
         self._is_fitted = False
+        self._are_models_merged = False

+        # Variables related to time-based topic models
+        # - topic_models: Dictionary of trained BERTopic models for each timestamp.
+        self.topic_models: Dict[pd.Timestamp, BERTopic] = {}
+        # - doc_groups: Dictionary of document groups for each timestamp.
+        self.doc_groups: Dict[pd.Timestamp, List[str]] = {}
+        # - emb_groups: Dictionary of document embeddings for each timestamp.
+        self.emb_groups: Dict[pd.Timestamp, np.ndarray] = {}
+
+        # Variables containing info about merged topics
+        self.all_new_topics_df = None
+        self.all_merge_histories_df = None
+        self.merged_df = None
+
+        # Variables containing info about topic popularity
+        # - topic_sizes: Dictionary storing topic sizes and related information over time.
+        self.topic_sizes: Dict[int, Dict[str, Any]] = defaultdict(
+            lambda: defaultdict(list)
+        )
+        # - topic_last_popularity: Dictionary storing the last known popularity of each topic.
+        self.topic_last_popularity: Dict[int, float] = {}
+        # - topic_last_update: Dictionary storing the last update timestamp of each topic.
+        self.topic_last_update: Dict[int, pd.Timestamp] = {}

     def _train_by_period(
         self,
-        period: int,
-        group: Dict[pd.Timestamp, pd.DataFrame],
+        period: pd.Timestamp,
+        group: pd.DataFrame,
         embedding_model: SentenceTransformer,
         embeddings: np.ndarray,
     ) -> Tuple[
@@ -73,14 +91,14 @@ def _train_by_period(
         Train BERTopic models for a given time period from the grouped data.
         Args:
-            period: int (int): indice of the time period
-            grouped_data (Dict[pd.Timestamp, pd.DataFrame]): Dictionary of grouped data by timestamp.
+            period (pd.Timestamp): Timestamp of the time period
+            group (pd.DataFrame): Group of data associated to that timestamp.
             embedding_model (SentenceTransformer): Sentence transformer model for embeddings.
             embeddings (np.ndarray): Precomputed document embeddings.
         Returns:
-            Tuple[BERTopic, List[str], np.ndarray]]:
+            Tuple[BERTopic, List[str], np.ndarray]:
             - topic_model: trained BERTopic models for this period.
             - doc_group: document groups for this period.
             - emb_group: document embeddings for this period.
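With the new per-period signature, the caller iterates over the grouped dictionary and hands each (timestamp, DataFrame) pair to `_train_by_period`. A minimal sketch of such a calling loop (hypothetical; the actual loop body of `train_topic_models` is outside this hunk):

    # Assumed shape: grouped_data: Dict[pd.Timestamp, pd.DataFrame]
    topic_models, doc_groups, emb_groups = {}, {}, {}
    for period, group in grouped_data.items():
        topic_model, doc_group, emb_group = self._train_by_period(
            period, group, embedding_model, embeddings
        )
        topic_models[period] = topic_model
        doc_groups[period] = doc_group
        emb_groups[period] = emb_group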
@@ -153,25 +171,19 @@ def train_topic_models(
         grouped_data: Dict[pd.Timestamp, pd.DataFrame],
         embedding_model: SentenceTransformer,
         embeddings: np.ndarray,
-    ) -> Tuple[
-        Dict[pd.Timestamp, BERTopic],
-        Dict[pd.Timestamp, List[str]],
-        Dict[pd.Timestamp, np.ndarray],
-    ]:
+    ):
         """
         Train BERTopic models for each timestamp in the grouped data.
+        Stores Tuple[Dict[pd.Timestamp, BERTopic], Dict[pd.Timestamp, List[str]], Dict[pd.Timestamp, np.ndarray]]:
+        - topic_models: Dictionary of trained BERTopic models for each timestamp.
+        - doc_groups: Dictionary of document groups for each timestamp.
+        - emb_groups: Dictionary of document embeddings for each timestamp.
         Args:
             grouped_data (Dict[pd.Timestamp, pd.DataFrame]): Dictionary of grouped data by timestamp.
             embedding_model (SentenceTransformer): Sentence transformer model for embeddings.
             embeddings (np.ndarray): Precomputed document embeddings.
-        Returns:
-            Tuple[Dict[pd.Timestamp, BERTopic], Dict[pd.Timestamp, List[str]], Dict[pd.Timestamp, np.ndarray]]:
-            - topic_models: Dictionary of trained BERTopic models for each timestamp.
-            - doc_groups: Dictionary of document groups for each timestamp.
-            - emb_groups: Dictionary of document embeddings for each timestamp.
         """
         # TODO from topic_modelling = train_topic_models (modulo data transformation)
         # TODO rename to fit?
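For context, `grouped_data` is a plain dictionary keyed by timestamp. A self-contained sketch of building one with pandas (the column names `timestamp` and `text` and the monthly frequency are illustrative assumptions, not taken from this repository):

    import pandas as pd

    df = pd.DataFrame(
        {
            "timestamp": pd.to_datetime(["2024-01-01", "2024-01-15", "2024-02-01"]),
            "text": ["doc a", "doc b", "doc c"],
        }
    )

    # Bucket documents by month, matching the Dict[pd.Timestamp, pd.DataFrame]
    # shape expected by train_topic_models.
    grouped_data = {
        ts: group
        for ts, group in df.groupby(pd.Grouper(key="timestamp", freq="MS"))
        if not group.empty
    }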
@@ -212,20 +224,21 @@ def train_topic_models(
         )
         """

-        logger.debug("Finished training all topic models")
         self._is_fitted = True

+        # Update topic_models: Dictionary of trained BERTopic models for each timestamp.
+        self.topic_models = topic_models
+        # Update doc_groups: Dictionary of document groups for each timestamp.
+        self.doc_groups = doc_groups
+        # Update emb_groups: Dictionary of document embeddings for each timestamp.
+        self.emb_groups = emb_groups

-        return topic_models, doc_groups, emb_groups
+        logger.debug("Finished training all topic models")

     def merge_models(
         self,
         min_similarity: int = DEFAULT_MIN_SIMILARITY,
         granularity: int = DEFAULT_GRANULARITY,
     ):
+        # TODO: add func description
         if not self._is_fitted:
             raise RuntimeError("You must fit the BERTrend model before merging models.")
@@ -298,40 +311,25 @@ def merge_models(
         all_merge_histories_df = pd.concat(all_merge_histories, ignore_index=True)
         all_new_topics_df = pd.concat(all_new_topics, ignore_index=True)

-        # SessionStateManager.set_multiple(
-        #     merged_df=merged_df_without_outliers,
-        #     all_merge_histories_df=all_merge_histories_df,
-        #     all_new_topics_df=all_new_topics_df,
-        # )
-
-        # SessionStateManager.set("models_merged", True)
-        self.model_merged = True
-
-        return merged_df_without_outliers, all_merge_histories_df, all_new_topics_df
+        self.merged_df = merged_df_without_outliers
+        self.all_merge_histories_df = all_merge_histories_df
+        self.all_new_topics_df = all_new_topics_df

-        (
-            topic_sizes,
-            topic_last_popularity,
-            topic_last_update,
-        ) = calculate_signal_popularity(all_merge_histories_df, granularity)
-        # SessionStateManager.set_multiple(
-        #     topic_sizes=topic_sizes,
-        #     topic_last_popularity=topic_last_popularity,
-        #     topic_last_update=topic_last_update,
-        # )
-
-        # TODO: update / set local vars
+        self._are_models_merged = True

-    # TODO: avoid parameter passing, use internal vars instead
     def calculate_signal_popularity(
         self,
-        all_merge_histories_df: pd.DataFrame,
-        granularity: int,
+        granularity: int = DEFAULT_GRANULARITY,
         decay_factor: float = 0.01,
         decay_power: float = 2,
-    ) -> Tuple[Dict[int, Dict[str, Any]], Dict[int, float], Dict[int, pd.Timestamp]]:
+    ):
         """
         Calculate the popularity of signals (topics) over time, accounting for merges and applying decay.
+        Updates:
+        - topic_sizes (Dict[int, Dict[str, Any]]): Dictionary storing topic sizes and related information over time.
+        - topic_last_popularity (Dict[int, float]): Dictionary storing the last known popularity of each topic.
+        - topic_last_update (Dict[int, pd.Timestamp]): Dictionary storing the last update timestamp of each topic.
         Args:
             all_merge_histories_df (pd.DataFrame): DataFrame containing all merge histories.
@@ -340,17 +338,20 @@ def calculate_signal_popularity(
             decay_power (float): Power for exponential decay calculation.
-        Returns:
-            Tuple[Dict[int, Dict[str, Any]], Dict[int, float], Dict[int, pd.Timestamp]]:
-            - topic_sizes: Dictionary storing topic sizes and related information over time.
-            - topic_last_popularity: Dictionary storing the last known popularity of each topic.
-            - topic_last_update: Dictionary storing the last update timestamp of each topic.
         """
+        if not self._are_models_merged:
+            # FIXME: RuntimeError
+            raise RuntimeError(
+                "You must merge topic models first before computing signal popularity."
+            )

         topic_sizes = defaultdict(lambda: defaultdict(list))
         topic_last_popularity = {}
         topic_last_update = {}

-        min_timestamp = all_merge_histories_df["Timestamp"].min()
-        max_timestamp = all_merge_histories_df["Timestamp"].max()
+        min_timestamp = self.all_merge_histories_df["Timestamp"].min()
+        max_timestamp = self.all_merge_histories_df["Timestamp"].max()
         granularity_timedelta = pd.Timedelta(days=granularity)
         time_range = pd.date_range(
             start=min_timestamp.to_pydatetime(),
@@ -359,8 +360,8 @@
         )

         for current_timestamp in time_range:
-            current_df = all_merge_histories_df[
-                all_merge_histories_df["Timestamp"] == current_timestamp
+            current_df = self.all_merge_histories_df[
+                self.all_merge_histories_df["Timestamp"] == current_timestamp
             ]
             updated_topics = set()
@@ -396,8 +397,8 @@

             # Apply decay to topics that weren't updated in this timestamp or the next
             next_timestamp = current_timestamp + granularity_timedelta
-            next_df = all_merge_histories_df[
-                all_merge_histories_df["Timestamp"] == next_timestamp
+            next_df = self.all_merge_histories_df[
+                self.all_merge_histories_df["Timestamp"] == next_timestamp
             ]
             topics_updated_next = set(next_df["Topic1"])
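The decay update itself sits outside this hunk, so the exact formula is not visible here. Purely for intuition, one plausible exponential form consistent with the `decay_factor` and `decay_power` parameters is sketched below; treat it as an assumption, not the repository's confirmed implementation:

    import numpy as np

    def apply_decay_sketch(
        last_popularity: float,
        periods_inactive: int,
        decay_factor: float = 0.01,
        decay_power: float = 2,
    ) -> float:
        # Hypothetical rule: popularity shrinks exponentially with the number
        # of periods a topic has gone without an update.
        return last_popularity * np.exp(-decay_factor * periods_inactive**decay_power)

    # e.g. a topic with popularity 10.0 untouched for 3 periods:
    print(apply_decay_sketch(10.0, 3))  # 10 * exp(-0.01 * 9) ≈ 9.14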

@@ -413,7 +414,9 @@
                     decay_power,
                 )

-        return topic_sizes, topic_last_popularity, topic_last_update
+        self.topic_sizes = topic_sizes
+        self.topic_last_popularity = topic_last_popularity
+        self.topic_last_update = topic_last_update

     #####################################################################################################
     # FIXME: WIP
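Taken together, the commit moves the demo app's session state into the `BERTrend` object itself. An end-to-end usage sketch under the same assumptions as above (constructor arguments elided, method and attribute names as shown in this diff):

    bertrend = BERTrend(...)  # constructor arguments elided

    bertrend.train_topic_models(grouped_data, embedding_model, embeddings)
    bertrend.merge_models()                 # stores merged_df and friends, sets _are_models_merged
    bertrend.calculate_signal_popularity()  # fills the three popularity dicts

    # Results now live on the instance instead of Streamlit session state:
    for topic_id, popularity in bertrend.topic_last_popularity.items():
        print(topic_id, popularity, bertrend.topic_last_update[topic_id])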
