From 69cbfc05d6f55f513d6517f6b8b485a2f96efe19 Mon Sep 17 00:00:00 2001 From: Guillaume Grosjean Date: Wed, 8 Jan 2025 11:12:59 +0100 Subject: [PATCH] Modified data loading options and rendering in streamlit demo --- .../demos_utils/data_loading_component.py | 138 ++++++++++-------- bertrend/parameters.py | 4 +- 2 files changed, 78 insertions(+), 64 deletions(-) diff --git a/bertrend/demos/demos_utils/data_loading_component.py b/bertrend/demos/demos_utils/data_loading_component.py index ce43cd8..2796c39 100644 --- a/bertrend/demos/demos_utils/data_loading_component.py +++ b/bertrend/demos/demos_utils/data_loading_component.py @@ -100,6 +100,9 @@ def display_data_loading_component(): split by paragraph from initial documents and in all cases filtered by dates) is stored inside the Streamlit state variable "time_filtered_df". """ + # Data loading section + st.header("Data loading") + # Find files in the current directory and subdirectories tab1, tab2 = st.tabs( [ @@ -136,32 +139,7 @@ def display_data_loading_component(): st.warning(NO_DATASET_WARNING, icon=WARNING_ICON) st.stop() - # Display number input and checkbox for preprocessing options - col1, col2 = st.columns(2) - with col1: - register_widget("min_chars") - st.number_input( - "Minimum Characters", - value=MIN_CHARS_DEFAULT, - min_value=0, - max_value=1000, - key="min_chars", - on_change=save_widget_state, - ) - with col2: - register_widget("split_by_paragraph") - SessionStateManager.get_or_set("split_by_paragraph", "yes") - st.segmented_control( - "Split text by paragraphs", - key="split_by_paragraph", - options=["no", "yes", "enhanced"], - selection_mode="single", - help="'No split': No splitting on the documents ; 'Split by paragraphs': Split documents into paragraphs ; " - "'Enhanced split': uses a more advanced but slower method for splitting that considers the embedding " - "model's maximum input length.", - on_change=save_widget_state, - ) - # Load and preprocess each selected file, then concatenate 
them + # Load each selected file, then concatenate them # Priority to local data if both are set dfs = None if SessionStateManager.get("uploaded_files"): @@ -172,18 +150,65 @@ def display_data_loading_component(): dfs = _load_files( SessionStateManager.get("selected_files"), ) - + # Check if DataFrames have been found if not dfs: st.warning( NO_DATA_AFTER_PREPROCESSING_MESSAGE, icon=WARNING_ICON, ) else: + # Concatenate DataFrames df = pd.concat(dfs, ignore_index=True) # Save state of initial DF (before split and data selection) st.session_state["initial_df"] = df.copy() + # Show raw data info + st.write( + f"Number of documents in raw data: **{len(st.session_state['initial_df'])}**" + ) + + # Data filtering section + st.header("Data filtering") + + # Display number input and checkbox for preprocessing options + col1, col2, col3 = st.columns(3) + with col1: + register_widget("min_chars") + st.number_input( + "Minimum Characters", + value=MIN_CHARS_DEFAULT, + min_value=0, + max_value=1000, + key="min_chars", + on_change=save_widget_state, + help="Minimum number of characters each document must contain.", + ) + with col2: + register_widget("sample_size") + sample_size = st.number_input( + "Sample ratio", + value=SAMPLE_SIZE_DEFAULT, + min_value=0.0, + max_value=1.0, + key="sample_size", + on_change=save_widget_state, + help="Fraction of raw data to use for computing topics. 
Randomly samples documents from raw data.", + ) + with col3: + register_widget("split_by_paragraph") + SessionStateManager.get_or_set("split_by_paragraph", "no") + st.segmented_control( + "Split text by paragraphs", + key="split_by_paragraph", + options=["no", "yes", "enhanced"], + selection_mode="single", + help="'No split': No splitting on the documents ; 'Split by paragraphs': Split documents into paragraphs ; " + "'Enhanced split': uses a more advanced but slower method for splitting that considers the embedding " + "model's maximum input length.", + on_change=save_widget_state, + ) + df = split_data( df, SessionStateManager.get("min_chars"), @@ -197,42 +222,31 @@ def display_data_loading_component(): # Save state of split dataframe (before time-based filtering) st.session_state["split_df"] = df.copy() - col1, col2 = st.columns([0.8, 0.2]) - with col1: - # Select timeframe - min_date, max_date = df[TIMESTAMP_COLUMN].dt.date.agg(["min", "max"]) - register_widget("timeframe_slider") - start_date, end_date = st.slider( - "Select Timeframe", - min_value=min_date, - max_value=max_date, - value=(min_date, max_date), - key="timeframe_slider", - on_change=save_widget_state, - ) - - # Filter and sample the DataFrame - df_filtered = df[ - (df[TIMESTAMP_COLUMN].dt.date >= start_date) - & (df[TIMESTAMP_COLUMN].dt.date <= end_date) - ] - df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index( - drop=True - ) + # Select timeframe + min_date, max_date = st.session_state["initial_df"][ + TIMESTAMP_COLUMN + ].dt.date.agg(["min", "max"]) + register_widget("timeframe_slider") + start_date, end_date = st.slider( + "Select Timeframe", + min_value=min_date, + max_value=max_date, + value=(min_date, max_date), + key="timeframe_slider", + on_change=save_widget_state, + ) - with col2: - register_widget("sample_size") - sample_size = st.number_input( - "Sample Size", - value=SAMPLE_SIZE_DEFAULT or len(df_filtered), - min_value=1, - max_value=len(df_filtered), - 
key="sample_size", - on_change=save_widget_state, - ) + # Filter and sample the DataFrame + df_filtered = df[ + (df[TIMESTAMP_COLUMN].dt.date >= start_date) + & (df[TIMESTAMP_COLUMN].dt.date <= end_date) + ] + df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index( + drop=True + ) - if sample_size < len(df_filtered): - df_filtered = df_filtered.sample(n=sample_size, random_state=42) + if sample_size < 1: + df_filtered = df_filtered.sample(frac=sample_size, random_state=42) df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index( drop=True @@ -240,7 +254,7 @@ def display_data_loading_component(): SessionStateManager.set("time_filtered_df", df_filtered) st.write( - f"Number of documents in selected timeframe: {len(SessionStateManager.get_dataframe('time_filtered_df'))}" + f"Number of documents in filtered data: **{len(SessionStateManager.get_dataframe('time_filtered_df'))}**" ) st.dataframe( SessionStateManager.get_dataframe("time_filtered_df")[ diff --git a/bertrend/parameters.py b/bertrend/parameters.py index 660c889..831ac76 100644 --- a/bertrend/parameters.py +++ b/bertrend/parameters.py @@ -101,8 +101,8 @@ OPENAI_NR_DOCS = 5 # Data Processing -MIN_CHARS_DEFAULT = 100 -SAMPLE_SIZE_DEFAULT = None # Or whatever default you want, None means all documents +MIN_CHARS_DEFAULT = 1 +SAMPLE_SIZE_DEFAULT = 1.0 # Fraction of documents to sample (0.0 to 1.0); 1.0 means all documents # Time Settings DEFAULT_WINDOW_SIZE = 7 # days