Skip to content

Commit

Permalink
Improvements notebook, bug fix in merge_all_models
Browse files Browse the repository at this point in the history
  • Loading branch information
picaultj committed Jan 27, 2025
1 parent 7009d59 commit e463cf5
Show file tree
Hide file tree
Showing 2 changed files with 364 additions and 1,512 deletions.
18 changes: 13 additions & 5 deletions bertrend/BERTrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,23 @@ def merge_all_models(
}

timestamps = sorted(topic_dfs.keys())

if len(self.topic_models)<2: # beginning of the process, no real merge needed
logger.warning("This function requires at least two topic models. Ignored")
self._are_models_merged = False
return

assert len(self.topic_models) >= 2

merged_df_without_outliers = None
all_merge_histories = []
all_new_topics = []

# TODO: tqdm
merge_df_size_over_time = []

for i, (current_timestamp, next_timestamp) in enumerate(
zip(timestamps[:-1], timestamps[1:])
for i, (current_timestamp, next_timestamp) in tqdm(
enumerate(zip(timestamps[:-1], timestamps[1:]))
):
df1 = topic_dfs[current_timestamp][
topic_dfs[current_timestamp]["Topic"] != -1
Expand All @@ -313,7 +321,7 @@ def merge_all_models(
) = _merge_models(
df1,
df2,
min_similarity=min_similarity, # SessionStateManager.get("min_similarity"),
min_similarity=min_similarity,
timestamp=current_timestamp,
)
elif not df2.empty:
Expand All @@ -324,15 +332,15 @@ def merge_all_models(
) = _merge_models(
merged_df_without_outliers,
df2,
min_similarity=min_similarity, # SessionStateManager.get("min_similarity"),
min_similarity=min_similarity,
timestamp=current_timestamp,
)
else:
continue

all_merge_histories.append(merge_history)
all_new_topics.append(new_topics)
merge_df_size_over_time = merge_df_size_over_time # SessionStateManager.get("merge_df_size_over_time")
merge_df_size_over_time = merge_df_size_over_time
merge_df_size_over_time.append(
(
current_timestamp,
Expand Down
Loading

0 comments on commit e463cf5

Please sign in to comment.