From 520201925f110769f6d03cb8061ba47b40f22a62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Picault?= Date: Mon, 30 Dec 2024 17:06:48 +0100 Subject: [PATCH] debug --- bertrend/demos/demos_utils/data_loading_component.py | 12 ++++++++---- .../topic_analysis/demo_pages/explore_topics.py | 9 ++++++--- .../demo_pages/newsletters_generation.py | 2 +- .../demo_pages/topic_visualizations.py | 4 +--- .../demos/topic_analysis/demo_pages/training_page.py | 4 ++++ bertrend/topic_model.py | 3 ++- 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/bertrend/demos/demos_utils/data_loading_component.py b/bertrend/demos/demos_utils/data_loading_component.py index 3f9ad18..f720d83 100644 --- a/bertrend/demos/demos_utils/data_loading_component.py +++ b/bertrend/demos/demos_utils/data_loading_component.py @@ -125,7 +125,7 @@ def display_data_loading_component(): col1, col2 = st.columns(2) with col1: register_widget("min_chars") - min_chars = st.number_input( + st.number_input( "Minimum Characters", value=MIN_CHARS_DEFAULT, min_value=0, @@ -136,7 +136,7 @@ def display_data_loading_component(): with col2: register_widget("split_by_paragraph") SessionStateManager.get_or_set("split_by_paragraph", "yes") - split_by_paragraph = st.segmented_control( + st.segmented_control( "Split text by paragraphs", key="split_by_paragraph", options=["no", "yes", "enhanced"], @@ -151,11 +151,15 @@ def display_data_loading_component(): dfs = None if SessionStateManager.get("uploaded_files"): dfs = _process_uploaded_files( - SessionStateManager.get("uploaded_files"), min_chars, split_by_paragraph + SessionStateManager.get("uploaded_files"), + SessionStateManager.get("min_chars"), + SessionStateManager.get("split_by_paragraph"), ) elif SessionStateManager.get("selected_files"): dfs = _load_files( - SessionStateManager.get("selected_files"), min_chars, split_by_paragraph + SessionStateManager.get("selected_files"), + SessionStateManager.get("min_chars"), + SessionStateManager.get("split_by_paragraph"), ) if not dfs: diff --git a/bertrend/demos/topic_analysis/demo_pages/explore_topics.py b/bertrend/demos/topic_analysis/demo_pages/explore_topics.py index e9ded0a..02418bc 100644 --- a/bertrend/demos/topic_analysis/demo_pages/explore_topics.py +++ b/bertrend/demos/topic_analysis/demo_pages/explore_topics.py @@ -16,7 +16,7 @@ from loguru import logger -from bertrend import LLM_CONFIG +from bertrend import LLM_CONFIG, OUTPUT_PATH from bertrend.demos.demos_utils.icons import ERROR_ICON, WARNING_ICON from bertrend.demos.demos_utils.session_state_manager import SessionStateManager from bertrend.demos.demos_utils.state_utils import restore_widget_state @@ -36,7 +36,8 @@ ) # Constants -EXPORT_BASE_FOLDER = Path(__file__).parent.parent / "exported_topics" +EXPORT_BASE_FOLDER = OUTPUT_PATH / "exported_topics" +EXPORT_BASE_FOLDER.mkdir(parents=True, exist_ok=True) def generate_topic_description(topic_model, topic_number, filtered_docs): @@ -176,7 +177,7 @@ def plot_topic_over_time(): def get_representative_documents(top_n_docs): """Get representative documents for the selected topic.""" - if st.session_state["split_by_paragraph"] in ["yes", "enhanced"]: + if st.session_state["split_type"] in ["yes", "enhanced"]: return get_most_representative_docs( st.session_state["topic_model"], st.session_state["initial_df"], @@ -482,6 +483,8 @@ def main(): # Restore widget state restore_widget_state() +# FIXME: debug +st.write(st.session_state) main() # FIXME: The number of documents being displayed per topic corresponds to the paragraphs, it should instead correspond to the number of original articles before splitting. diff --git a/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py b/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py index 70b31b3..d9b813d 100644 --- a/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py +++ b/bertrend/demos/topic_analysis/demo_pages/newsletters_generation.py @@ -146,7 +146,7 @@ def main(): # Generate newsletters when button is clicked if generate_newsletter_clicked: - if st.session_state["split_by_paragraph"] in ["yes", "enhanced"]: + if st.session_state["split_type"] in ["yes", "enhanced"]: df = st.session_state["initial_df"] df_split = st.session_state["time_filtered_df"] else: diff --git a/bertrend/demos/topic_analysis/demo_pages/topic_visualizations.py b/bertrend/demos/topic_analysis/demo_pages/topic_visualizations.py index 3649283..ecd33d2 100644 --- a/bertrend/demos/topic_analysis/demo_pages/topic_visualizations.py +++ b/bertrend/demos/topic_analysis/demo_pages/topic_visualizations.py @@ -191,8 +191,6 @@ def create_datamap(include_outliers): enable_search=True, search_field="hover_text", point_line_width=0, - logo="https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/RTE_logo.svg/1024px-RTE_logo.svg.png", - logo_width=100, ) save_path = OUTPUT_PATH / "datamapplot.html" @@ -252,7 +250,7 @@ def main(): components.html(datamap_html, width=1200, height=1000, scrolling=True) # Add the fullscreen button - save_path = Path(__file__).parent.parent / "datamapplot.html" + save_path = OUTPUT_PATH / "datamapplot.html" # Create a download button st.download_button( diff --git a/bertrend/demos/topic_analysis/demo_pages/training_page.py b/bertrend/demos/topic_analysis/demo_pages/training_page.py index 2df76e6..6eb766a 100644 --- a/bertrend/demos/topic_analysis/demo_pages/training_page.py +++ b/bertrend/demos/topic_analysis/demo_pages/training_page.py @@ -169,6 +169,10 @@ def main(): # Load data display_data_loading_component() + # FIXME: debug + st.write(st.session_state) + SessionStateManager.set("split_type", st.session_state["split_by_paragraph"]) + # Data overview if "time_filtered_df" not in st.session_state: st.stop() diff --git a/bertrend/topic_model.py b/bertrend/topic_model.py index 817c8f8..d056300 100644 --- a/bertrend/topic_model.py +++ b/bertrend/topic_model.py @@ -211,7 +211,8 @@ def fit( logger.success("\tBERTopic model fitted successfully") output = TopicModelOutput(topic_model) - + output.topics = new_topics + output.probs = probs return output except Exception as e: logger.error(f"\tError in create_topic_model: {str(e)}")