Commit 69cbfc0
Modified data loading options and rendering in streamlit demo
grosjeang committed Jan 8, 2025
1 parent 1ba936e commit 69cbfc0
Showing 2 changed files with 78 additions and 64 deletions.
138 changes: 76 additions & 62 deletions bertrend/demos/demos_utils/data_loading_component.py
@@ -100,6 +100,9 @@ def display_data_loading_component():
     split by paragraph from initial documents and in all cases filtered by dates) is stored inside the Streamlit
     state variable "time_filtered_df".
     """
+    # Data loading section
+    st.header("Data loading")
+
     # Find files in the current directory and subdirectories
     tab1, tab2 = st.tabs(
         [
@@ -136,32 +139,7 @@ def display_data_loading_component():
         st.warning(NO_DATASET_WARNING, icon=WARNING_ICON)
         st.stop()

-    # Display number input and checkbox for preprocessing options
-    col1, col2 = st.columns(2)
-    with col1:
-        register_widget("min_chars")
-        st.number_input(
-            "Minimum Characters",
-            value=MIN_CHARS_DEFAULT,
-            min_value=0,
-            max_value=1000,
-            key="min_chars",
-            on_change=save_widget_state,
-        )
-    with col2:
-        register_widget("split_by_paragraph")
-        SessionStateManager.get_or_set("split_by_paragraph", "yes")
-        st.segmented_control(
-            "Split text by paragraphs",
-            key="split_by_paragraph",
-            options=["no", "yes", "enhanced"],
-            selection_mode="single",
-            help="'No split': No splitting on the documents ; 'Split by paragraphs': Split documents into paragraphs ; "
-            "'Enhanced split': uses a more advanced but slower method for splitting that considers the embedding "
-            "model's maximum input length.",
-            on_change=save_widget_state,
-        )
-    # Load and preprocess each selected file, then concatenate them
+    # Load each selected file, then concatenate them
     # Priority to local data if both are set
     dfs = None
     if SessionStateManager.get("uploaded_files"):
@@ -172,18 +150,65 @@ def display_data_loading_component():
         dfs = _load_files(
             SessionStateManager.get("selected_files"),
         )

     # Check if DataFrames have been found
     if not dfs:
         st.warning(
             NO_DATA_AFTER_PREPROCESSING_MESSAGE,
             icon=WARNING_ICON,
         )
     else:
         # Concatenate DataFrames
         df = pd.concat(dfs, ignore_index=True)

         # Save state of initial DF (before split and data selection)
         st.session_state["initial_df"] = df.copy()

+        # Show raw data info
+        st.write(
+            f"Number of documents in raw data: **{len(st.session_state['initial_df'])}**"
+        )
+
+        # Data filtering section
+        st.header("Data filtering")
+
+        # Display number inputs and segmented control for preprocessing options
+        col1, col2, col3 = st.columns(3)
+        with col1:
+            register_widget("min_chars")
+            st.number_input(
+                "Minimum Characters",
+                value=MIN_CHARS_DEFAULT,
+                min_value=0,
+                max_value=1000,
+                key="min_chars",
+                on_change=save_widget_state,
+                help="Minimum number of characters each document must contain.",
+            )
+        with col2:
+            register_widget("sample_size")
+            sample_size = st.number_input(
+                "Sample ratio",
+                value=SAMPLE_SIZE_DEFAULT,
+                min_value=0.0,
+                max_value=1.0,
+                key="sample_size",
+                on_change=save_widget_state,
+                help="Fraction of raw data to use for computing topics. Randomly samples documents from raw data.",
+            )
+        with col3:
+            register_widget("split_by_paragraph")
+            SessionStateManager.get_or_set("split_by_paragraph", "no")
+            st.segmented_control(
+                "Split text by paragraphs",
+                key="split_by_paragraph",
+                options=["no", "yes", "enhanced"],
+                selection_mode="single",
+                help="'No split': no splitting of the documents; 'Split by paragraphs': split documents into paragraphs; "
+                "'Enhanced split': uses a more advanced but slower splitting method that respects the embedding "
+                "model's maximum input length.",
+                on_change=save_widget_state,
+            )
+
         df = split_data(
             df,
             SessionStateManager.get("min_chars"),
@@ -197,50 +222,39 @@ def display_data_loading_component():
         # Save state of split dataframe (before time-based filtering)
         st.session_state["split_df"] = df.copy()

-        col1, col2 = st.columns([0.8, 0.2])
-        with col1:
-            # Select timeframe
-            min_date, max_date = df[TIMESTAMP_COLUMN].dt.date.agg(["min", "max"])
-            register_widget("timeframe_slider")
-            start_date, end_date = st.slider(
-                "Select Timeframe",
-                min_value=min_date,
-                max_value=max_date,
-                value=(min_date, max_date),
-                key="timeframe_slider",
-                on_change=save_widget_state,
-            )
-
-            # Filter and sample the DataFrame
-            df_filtered = df[
-                (df[TIMESTAMP_COLUMN].dt.date >= start_date)
-                & (df[TIMESTAMP_COLUMN].dt.date <= end_date)
-            ]
-            df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index(
-                drop=True
-            )
+        # Select timeframe
+        min_date, max_date = st.session_state["initial_df"][
+            TIMESTAMP_COLUMN
+        ].dt.date.agg(["min", "max"])
+        register_widget("timeframe_slider")
+        start_date, end_date = st.slider(
+            "Select Timeframe",
+            min_value=min_date,
+            max_value=max_date,
+            value=(min_date, max_date),
+            key="timeframe_slider",
+            on_change=save_widget_state,
+        )

-        with col2:
-            register_widget("sample_size")
-            sample_size = st.number_input(
-                "Sample Size",
-                value=SAMPLE_SIZE_DEFAULT or len(df_filtered),
-                min_value=1,
-                max_value=len(df_filtered),
-                key="sample_size",
-                on_change=save_widget_state,
-            )
+        # Filter and sample the DataFrame
+        df_filtered = df[
+            (df[TIMESTAMP_COLUMN].dt.date >= start_date)
+            & (df[TIMESTAMP_COLUMN].dt.date <= end_date)
+        ]
+        df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index(
+            drop=True
+        )

-        if sample_size < len(df_filtered):
-            df_filtered = df_filtered.sample(n=sample_size, random_state=42)
+        if sample_size < 1:
+            df_filtered = df_filtered.sample(frac=sample_size, random_state=42)

+        df_filtered = df_filtered.sort_values(by=TIMESTAMP_COLUMN).reset_index(
+            drop=True
+        )

         SessionStateManager.set("time_filtered_df", df_filtered)
         st.write(
-            f"Number of documents in selected timeframe: {len(SessionStateManager.get_dataframe('time_filtered_df'))}"
+            f"Number of documents in filtered data: **{len(SessionStateManager.get_dataframe('time_filtered_df'))}**"
        )
         st.dataframe(
             SessionStateManager.get_dataframe("time_filtered_df")[
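The segmented control above exposes three split modes, but the implementation of split_data is not part of this diff. As a purely illustrative sketch of the distinction the help text describes (function name, variable names, and the character-based length limit are all hypothetical; the real "enhanced" mode works against the embedding model's token limit):

# Hypothetical illustration of the three split modes; not the actual
# split_data from bertrend, whose code is not shown in this commit.
def split_documents(texts: list[str], mode: str, max_chars: int = 2000) -> list[str]:
    if mode == "no":
        # Keep each document whole
        return texts
    chunks: list[str] = []
    for text in texts:
        # Plain paragraph split on blank lines
        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
        if mode == "yes":
            chunks.extend(paragraphs)
        else:  # "enhanced": additionally cap each chunk at the model's input limit
            for p in paragraphs:
                while len(p) > max_chars:
                    chunks.append(p[:max_chars])
                    p = p[max_chars:]
                chunks.append(p)
    return chunks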
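The sampling semantics also change in this commit: the old widget asked for an absolute "Sample Size" (n documents, capped at the filtered length), while the new "Sample ratio" takes a fraction in [0, 1] and only samples when it is below 1. A minimal sketch of the new filter-and-sample flow outside Streamlit, assuming a DataFrame with a datetime "timestamp" column (standing in for TIMESTAMP_COLUMN); the helper name is hypothetical, and only the date filtering, frac-based sampling, and random_state=42 mirror the diff:

import datetime

import pandas as pd

def filter_and_sample(
    df: pd.DataFrame,
    start_date: datetime.date,
    end_date: datetime.date,
    sample_size: float,
) -> pd.DataFrame:
    # Keep only documents whose date falls inside the selected timeframe
    dates = df["timestamp"].dt.date
    filtered = df[(dates >= start_date) & (dates <= end_date)]
    # sample_size is a ratio: values below 1 trigger fraction-based sampling,
    # while 1.0 keeps every document
    if sample_size < 1:
        filtered = filtered.sample(frac=sample_size, random_state=42)
    # Re-sort after sampling, since .sample() shuffles row order
    return filtered.sort_values(by="timestamp").reset_index(drop=True)

This is why the diff re-sorts df_filtered a second time after the sampling step before storing the result in session state as "time_filtered_df": sampling would otherwise leave the documents out of chronological order.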
4 changes: 2 additions & 2 deletions bertrend/parameters.py
@@ -101,8 +101,8 @@
 OPENAI_NR_DOCS = 5

 # Data Processing
-MIN_CHARS_DEFAULT = 100
-SAMPLE_SIZE_DEFAULT = None  # Or whatever default you want, None means all documents
+MIN_CHARS_DEFAULT = 1
+SAMPLE_SIZE_DEFAULT = 1.0  # 1.0 means all documents; lower the ratio to sample

 # Time Settings
 DEFAULT_WINDOW_SIZE = 7  # days
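The new defaults work together with the widgets added in the first file: MIN_CHARS_DEFAULT = 1 effectively disables length filtering out of the box, and SAMPLE_SIZE_DEFAULT = 1.0 keeps every document unless the user lowers the ratio. A short sketch of that interaction, assuming the widget wiring shown above (session-state plumbing and the key/on_change arguments omitted):

import streamlit as st

# New defaults from parameters.py
MIN_CHARS_DEFAULT = 1      # was 100: length filtering is effectively off by default
SAMPLE_SIZE_DEFAULT = 1.0  # was None: 1.0 now explicitly means "keep all documents"

sample_size = st.number_input(
    "Sample ratio",
    value=SAMPLE_SIZE_DEFAULT,
    min_value=0.0,
    max_value=1.0,
)
# At the default of 1.0, the downstream `if sample_size < 1` branch never
# fires, so the full dataset is used unless the user lowers the ratio.

Switching from an absolute count to a ratio also removes the old widget's dependency on len(df_filtered) for its bounds, so the control no longer has to be rebuilt when the timeframe changes.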
