Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
grosjeang committed Feb 11, 2025
1 parent cb98592 commit ec819e2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 14 deletions.
15 changes: 10 additions & 5 deletions bertrend_apps/exploration/curebot/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@

# Set sidebar
with st.sidebar:
st.header("BERTopic")
with st.expander("Paramètres"):
st.checkbox("Utiliser les tags", key="use_tags", value=False)
st.slider(
st.checkbox(
"Utiliser les tags",
key="use_tags",
value=False,
help="Utiliser les tags Curebot pour orienter la recherche de sujets.",
)
st.number_input(
"Nombre d'articles minimum par sujet",
1,
2,
50,
10,
5,
key="min_articles_per_topic",
help="Permet d'influencer le nombre total de sujets trouvés par le modèle. Plus ce nombre est élevé, moins il y aura de sujets.",
)

# Create tabs
Expand Down
17 changes: 10 additions & 7 deletions bertrend_apps/exploration/curebot/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)


@st.cache_data
@st.cache_data(show_spinner=False)
def concat_data_from_files(files: list[UploadedFile]) -> pd.DataFrame:
"""
Concatenate data from multiple Excel files into a single DataFrame.
Expand All @@ -54,7 +54,7 @@ def concat_data_from_files(files: list[UploadedFile]) -> pd.DataFrame:
return df


@st.cache_data
@st.cache_data(show_spinner=False)
def chunk_df(
df: pd.DataFrame, chunk_size: int = 100, overlap: int = 20
) -> pd.DataFrame:
Expand Down Expand Up @@ -91,23 +91,26 @@ def chunk_df(
return df.copy()


@st.cache_data
@st.cache_data(show_spinner=False)
def get_embeddings(texts: list[str]) -> np.ndarray:
"""Get embeddings for a list of texts."""
return EMBEDDING_MODEL.encode(texts)


@st.cache_data
@st.cache_data(show_spinner=False)
def fit_bertopic(
docs: list[str],
embeddings: np.ndarray,
min_articles_per_topic: int,
zeroshot_topic_list: list[str] | None = None,
) -> tuple[BERTopic, list[int]]:
"""
Fit BERTopic model on a list of documents and their embeddings.
"""
# Override default parameters
bertopic_config = {"hdbscan_model": {"min_cluster_size": min_articles_per_topic}}
# Initialize topic model
topic_model = BERTopicModel()
topic_model = BERTopicModel(config=bertopic_config)

# Train topic model
topic_model_output = topic_model.fit(
Expand All @@ -121,7 +124,7 @@ def fit_bertopic(
return topic_model_output.topic_model, topic_model_output.topics


@st.cache_data
@st.cache_data(show_spinner=False)
def get_improved_topic_description(
df: pd.DataFrame, _topics_info: pd.DataFrame
) -> list[str]:
Expand Down Expand Up @@ -151,7 +154,7 @@ def get_improved_topic_description(
return improved_descriptions


@st.cache_data
@st.cache_data(show_spinner=False)
def create_newsletter(
df: pd.DataFrame,
topics_info: pd.DataFrame,
Expand Down
14 changes: 12 additions & 2 deletions bertrend_apps/exploration/curebot/tabs/tab1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ def show() -> None:

# Load data into dataframe
if st.session_state.get("uploaded_files"):
preprocess_data()
try:
preprocess_data()
except Exception as e:
st.error(
f"Erreur lors du chargement des données. Vérifiez que vos données respectent le format Curebot attendu."
)

# If data is loaded
if "df" in st.session_state:
Expand Down Expand Up @@ -121,7 +126,12 @@ def train_model() -> None:
zeroshot_topic_list = None

# Train topic model
bertopic, topics = fit_bertopic(texts_list, embeddings, zeroshot_topic_list)
bertopic, topics = fit_bertopic(
texts_list,
embeddings,
st.session_state["min_articles_per_topic"],
zeroshot_topic_list=zeroshot_topic_list,
)

# Set session_state
st.session_state["topic_model"] = bertopic
Expand Down

0 comments on commit ec819e2

Please sign in to comment.