Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
picaultj committed Dec 30, 2024
1 parent bfdc296 commit 5202019
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 12 deletions.
12 changes: 8 additions & 4 deletions bertrend/demos/demos_utils/data_loading_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def display_data_loading_component():
col1, col2 = st.columns(2)
with col1:
register_widget("min_chars")
min_chars = st.number_input(
st.number_input(
"Minimum Characters",
value=MIN_CHARS_DEFAULT,
min_value=0,
Expand All @@ -136,7 +136,7 @@ def display_data_loading_component():
with col2:
register_widget("split_by_paragraph")
SessionStateManager.get_or_set("split_by_paragraph", "yes")
split_by_paragraph = st.segmented_control(
st.segmented_control(
"Split text by paragraphs",
key="split_by_paragraph",
options=["no", "yes", "enhanced"],
Expand All @@ -151,11 +151,15 @@ def display_data_loading_component():
dfs = None
if SessionStateManager.get("uploaded_files"):
dfs = _process_uploaded_files(
SessionStateManager.get("uploaded_files"), min_chars, split_by_paragraph
SessionStateManager.get("uploaded_files"),
SessionStateManager.get("min_chars"),
SessionStateManager.get("split_by_paragraph"),
)
elif SessionStateManager.get("selected_files"):
dfs = _load_files(
SessionStateManager.get("selected_files"), min_chars, split_by_paragraph
SessionStateManager.get("selected_files"),
SessionStateManager.get("min_chars"),
SessionStateManager.get("split_by_paragraph"),
)

if not dfs:
Expand Down
9 changes: 6 additions & 3 deletions bertrend/demos/topic_analysis/demo_pages/explore_topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from loguru import logger

from bertrend import LLM_CONFIG
from bertrend import LLM_CONFIG, OUTPUT_PATH
from bertrend.demos.demos_utils.icons import ERROR_ICON, WARNING_ICON
from bertrend.demos.demos_utils.session_state_manager import SessionStateManager
from bertrend.demos.demos_utils.state_utils import restore_widget_state
Expand All @@ -36,7 +36,8 @@
)

# Constants
EXPORT_BASE_FOLDER = Path(__file__).parent.parent / "exported_topics"
EXPORT_BASE_FOLDER = OUTPUT_PATH / "exported_topics"
EXPORT_BASE_FOLDER.mkdir(parents=True, exist_ok=True)


def generate_topic_description(topic_model, topic_number, filtered_docs):
Expand Down Expand Up @@ -176,7 +177,7 @@ def plot_topic_over_time():

def get_representative_documents(top_n_docs):
"""Get representative documents for the selected topic."""
if st.session_state["split_by_paragraph"] in ["yes", "enhanced"]:
if st.session_state["split_type"] in ["yes", "enhanced"]:
return get_most_representative_docs(
st.session_state["topic_model"],
st.session_state["initial_df"],
Expand Down Expand Up @@ -482,6 +483,8 @@ def main():

# Restore widget state
restore_widget_state()
# FIXME: debug
st.write(st.session_state)
main()

# FIXME: The number of documents being displayed per topic corresponds to the paragraphs, it should instead correspond to the number of original articles before splitting.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main():

# Generate newsletters when button is clicked
if generate_newsletter_clicked:
if st.session_state["split_by_paragraph"] in ["yes", "enhanced"]:
if st.session_state["split_type"] in ["yes", "enhanced"]:
df = st.session_state["initial_df"]
df_split = st.session_state["time_filtered_df"]
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,6 @@ def create_datamap(include_outliers):
enable_search=True,
search_field="hover_text",
point_line_width=0,
logo="https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/RTE_logo.svg/1024px-RTE_logo.svg.png",
logo_width=100,
)

save_path = OUTPUT_PATH / "datamapplot.html"
Expand Down Expand Up @@ -252,7 +250,7 @@ def main():
components.html(datamap_html, width=1200, height=1000, scrolling=True)

# Add the fullscreen button
save_path = Path(__file__).parent.parent / "datamapplot.html"
save_path = OUTPUT_PATH / "datamapplot.html"

# Create a download button
st.download_button(
Expand Down
4 changes: 4 additions & 0 deletions bertrend/demos/topic_analysis/demo_pages/training_page.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def main():
# Load data
display_data_loading_component()

# FIXME: debug
st.write(st.session_state)
SessionStateManager.set("split_type", st.session_state["split_by_paragraph"])

# Data overview
if "time_filtered_df" not in st.session_state:
st.stop()
Expand Down
3 changes: 2 additions & 1 deletion bertrend/topic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def fit(

logger.success("\tBERTopic model fitted successfully")
output = TopicModelOutput(topic_model)

output.topics = new_topics
output.probs = probs
return output
except Exception as e:
logger.error(f"\tError in create_topic_model: {str(e)}")
Expand Down

0 comments on commit 5202019

Please sign in to comment.