WIP: first version of script for partial training
picaultj committed Feb 4, 2025
1 parent c52f7a5 commit c39e88d
Showing 8 changed files with 236 additions and 81 deletions.
48 changes: 48 additions & 0 deletions bertrend/BERTrend.py
@@ -38,6 +38,7 @@
SIGNAL_CLASSIF_LOWER_BOUND,
SIGNAL_CLASSIF_UPPER_BOUND,
)
from bertrend.services.embedding_service import EmbeddingService
from bertrend.trend_analysis.weak_signals import (
_initialize_new_topic,
update_existing_topic,
@@ -763,6 +764,53 @@ def save_signal_evolution_data(
return save_path


def train_new_data(
new_data: pd.DataFrame,
bertrend_models_path: Path,
embedding_service: EmbeddingService,
) -> BERTrend:
"""Helper function for processing new data (incremental trend analysis:
- loads a previous saved BERTrend model
- train a new topic model with the new data
- merge the models and update merge histories
- save the model and returns it
"""
logger.debug(f"Processing new data: {len(new_data)} items")

# timestamp used to reference the model
reference_timestamp = new_data["timestamp"].max().date()

# Restore previous models
try:
bertrend = BERTrend.restore_models(bertrend_models_path)
    except Exception:
        logger.warning("Cannot restore previous models, creating a new one")
bertrend = BERTrend(topic_model=BERTopicModel())

# Embed new data
embeddings, token_strings, token_embeddings = embedding_service.embed(
texts=new_data[TEXT_COLUMN]
)
embedding_model_name = embedding_service.embedding_model_name

# Create topic model for new data
bertrend.train_topic_models(
{reference_timestamp: new_data},
embeddings=embeddings,
embedding_model=embedding_model_name,
)

# Merge models
bertrend.merge_all_models()

logger.info(f"BERTrend contains {len(bertrend.topic_models)} topic models")

# Save models
bertrend.save_models(models_path=bertrend_models_path)

return bertrend


def _preprocess_model(
topic_model: BERTopic, docs: list[str], embeddings: np.ndarray
) -> pd.DataFrame:
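
For context, a minimal usage sketch of the new train_new_data helper (not part of the diff). The DataFrame contents, models path, and embedding-service construction are hypothetical; the function only requires a DataFrame with text and timestamp columns:

from pathlib import Path

import pandas as pd

from bertrend.BERTrend import train_new_data
from bertrend.services.embedding_service import EmbeddingService

# Hypothetical batch of new documents (column names assumed: "text", "timestamp")
new_data = pd.DataFrame(
    {
        "text": ["First article body ...", "Second article body ..."],
        "timestamp": pd.to_datetime(["2025-02-01", "2025-02-03"]),
    }
)

# Restores (or creates) the saved model, trains on the new batch, merges, and saves
bertrend = train_new_data(
    new_data=new_data,
    bertrend_models_path=Path("models/my_feed"),  # hypothetical location
    embedding_service=EmbeddingService(),  # assumes a default-constructible service
)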
2 changes: 1 addition & 1 deletion bertrend/demos/demos_utils/icons.py
@@ -13,7 +13,7 @@
SETTINGS_ICON = ":material/settings:"
TOPIC_ICON = ":material/speaker_notes:"
TREND_ICON = ":material/trending_up:"
-MODELS_ICON = ":material/network_intel_node:"
+MODELS_ICON = ":material/network_intelligence:"
EMBEDDING_ICON = ":material/memory:"
SAVE_ICON = ":material/save:"
TOPIC_EXPLORATION_ICON = ":material/explore:"
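
Side note: these constants are Streamlit Material icon shortcodes. A typical (hypothetical) use:

import streamlit as st
from bertrend.demos.demos_utils.icons import MODELS_ICON

st.header(f"{MODELS_ICON} Models")  # Streamlit renders :material/...: codes in text elements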
8 changes: 8 additions & 0 deletions bertrend_apps/prospective_demo/.streamlit/secrets.toml
@@ -0,0 +1,8 @@
# .streamlit/secrets.toml

[passwords]
# Follow the rule: username = "password"
jerome = "jerome"
guillaume = "guillaume"
dsia = "dsia"
nemo = "nemo"
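
For reference, a minimal sketch of how such entries are typically read back via st.secrets; the actual authentication code is not part of this commit:

import streamlit as st

def check_password(username: str, password: str) -> bool:
    # st.secrets mirrors .streamlit/secrets.toml; the [passwords] table behaves like a mapping
    passwords = st.secrets.get("passwords", {})
    return passwords.get(username) == password

Plaintext credentials like these are only suitable for a demo setup.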
46 changes: 46 additions & 0 deletions bertrend_apps/prospective_demo/feeds_common.py
@@ -0,0 +1,46 @@
# Copyright (c) 2024, RTE (https://www.rte-france.com)
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
from pathlib import Path

from loguru import logger

from bertrend import FEED_BASE_PATH, load_toml_config

# Feed config path
USER_FEEDS_BASE_PATH = FEED_BASE_PATH / "users"
USER_FEEDS_BASE_PATH.mkdir(parents=True, exist_ok=True)


def get_user_feed_path(user_name: str, feed_id: str) -> Path:
    """Return the path of the feed config file for the given user and feed id."""
    feed_path = USER_FEEDS_BASE_PATH / user_name / f"{feed_id}_feed.toml"
    return feed_path


def read_user_feeds(username: str) -> tuple[dict[str, dict], dict[str, Path]]:
"""Read user feed config files"""
user_feed_dir = USER_FEEDS_BASE_PATH / username
user_feed_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Reading user feeds from: {user_feed_dir}")
matching_files = user_feed_dir.rglob("*_feed.toml")

user_feeds = {}
feed_files = {}
for f in matching_files:
feed_id = f.name.split("_feed.toml")[0]
user_feeds[feed_id] = load_toml_config(f)
feed_files[feed_id] = f

return user_feeds, feed_files


def get_all_files_for_feed(user_feeds: dict[str, dict], feed_id: str) -> list[Path]:
"""Returns the paths of all files associated to a feed for the current user."""
feed_base_dir = user_feeds[feed_id]["data-feed"]["feed_dir_path"]
list_all_files = list(
Path(FEED_BASE_PATH, feed_base_dir).glob(
f"*{user_feeds[feed_id]['data-feed'].get('id')}*.jsonl*"
)
)
return list_all_files
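
A short usage sketch of these two helpers (user name and feed ids hypothetical):

from bertrend_apps.prospective_demo.feeds_common import (
    get_all_files_for_feed,
    read_user_feeds,
)

user_feeds, feed_files = read_user_feeds("jerome")  # hypothetical user
for feed_id in user_feeds:
    files = get_all_files_for_feed(user_feeds, feed_id)
    print(f"{feed_id}: {len(files)} data file(s), config at {feed_files[feed_id]}")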
25 changes: 4 additions & 21 deletions bertrend_apps/prospective_demo/feeds_config.py
@@ -11,7 +11,6 @@
import toml
from loguru import logger

from bertrend import FEED_BASE_PATH, load_toml_config
from bertrend.config.parameters import LANGUAGES
from bertrend.demos.demos_utils.icons import (
INFO_ICON,
@@ -30,6 +29,10 @@
schedule_scrapping,
)
from bertrend_apps.data_provider import URL_PATTERN
from bertrend_apps.prospective_demo.feeds_common import (
USER_FEEDS_BASE_PATH,
read_user_feeds,
)
from bertrend_apps.prospective_demo.streamlit_utils import clickable_df

# Default feed configs
@@ -40,10 +43,6 @@
FEED_SOURCES = ["google", "curebot"]
TRANSLATION = {"English": "Anglais", "French": "Français"}

# Feed config path
USER_FEEDS_BASE_PATH = FEED_BASE_PATH / "users"
USER_FEEDS_BASE_PATH.mkdir(parents=True, exist_ok=True)


@st.dialog("Configuration d'un nouveau flux de données")
def edit_feed_monitoring(config: dict | None = None):
@@ -159,22 +158,6 @@ def display_crontab_description(crontab_expr: str) -> str:
return f":red[{ERROR_ICON} Expression mal écrite !]"


def read_user_feeds(username: str) -> tuple[dict[str, dict], dict[str, Path]]:
user_feed_dir = USER_FEEDS_BASE_PATH / username
user_feed_dir.mkdir(parents=True, exist_ok=True)
logger.debug(f"Reading user feeds from: {user_feed_dir}")
matching_files = user_feed_dir.rglob("*_feed.toml")

user_feeds = {}
feed_files = {}
for f in matching_files:
feed_id = f.name.split("_feed.toml")[0]
user_feeds[feed_id] = load_toml_config(f)
feed_files[feed_id] = f

return user_feeds, feed_files


def configure_information_sources():
"""Configure Information Sources."""
# if "user_feeds" not in st.session_state:
63 changes: 8 additions & 55 deletions bertrend_apps/prospective_demo/feeds_data.py
@@ -2,13 +2,11 @@
# See AUTHORS.txt
# SPDX-License-Identifier: MPL-2.0
# This file is part of BERTrend.
import os
import datetime
from pathlib import Path

import pandas as pd
import streamlit as st
from loguru import logger

from bertrend import FEED_BASE_PATH
from bertrend.utils.data_loading import (
@@ -18,6 +16,7 @@
URL_COLUMN,
TEXT_COLUMN,
)
from bertrend_apps.prospective_demo.feeds_common import get_all_files_for_feed


def display_data_status():
@@ -47,15 +46,16 @@ def display_data_status():


def display_data_info_for_feed(feed_id: str):
-    all_files = get_all_files_for_feed(feed_id)
+    all_files = get_all_files_for_feed(st.session_state.user_feeds, feed_id)
    df = get_all_data(files=all_files)
-    df = df[
-        [TITLE_COLUMN, URL_COLUMN, TEXT_COLUMN, TIMESTAMP_COLUMN]
-    ]  # filter useful columns

if df.empty:
df_filtered = pd.DataFrame()
else:
+        df = df[
+            [TITLE_COLUMN, URL_COLUMN, TEXT_COLUMN, TIMESTAMP_COLUMN]
+        ]  # filter useful columns

cutoff_date = datetime.datetime.now() - datetime.timedelta(
days=st.session_state.time_window
)
@@ -66,10 +66,8 @@
"# Fichiers": len(all_files),
"Date début": df[TIMESTAMP_COLUMN].min() if not df.empty else None,
"Date fin": df[TIMESTAMP_COLUMN].max() if not df.empty else None,
"Nombre d'articles": len(df),
f"Nombre d'articles (derniers {st.session_state.time_window} jours)": len(
df_filtered
),
"# Articles": len(df),
f"# Articles ({st.session_state.time_window} derniers jours)": len(df_filtered),
}

st.dataframe(pd.DataFrame([stats]))
@@ -93,48 +91,3 @@ def get_all_data(files: list[Path]) -> pd.DataFrame:
subset=["title"], keep="first", inplace=False
)
return new_df


def get_all_files_for_feed(feed_id: str) -> list[Path]:
"""Returns the paths of all files associated to a feed for the current user."""
feed_base_dir = st.session_state.user_feeds[feed_id]["data-feed"]["feed_dir_path"]
list_all_files = list(
Path(FEED_BASE_PATH, feed_base_dir).glob(
f"*{st.session_state.user_feeds[feed_id]['data-feed'].get('id')}*.jsonl*"
)
)
return list_all_files


def get_last_files(files: list[Path], time_window: int) -> list[Path] | None:
"""Returns the paths of all files associated to a feed for the current user in the last time window."""
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=time_window)
matching_files = []
for file in files:
try:
file_stat = file.stat() # Get file stats only once
print(file_stat)
file_time = datetime.datetime.fromtimestamp(file_stat.st_mtime)
if file_time >= cutoff_date:
matching_files.append(file)
except OSError as e:
logger.warning(f"Error accessing file {file}: {e}")
# Handle the error as needed (e.g., skip the file)


def get_first_file(files: list[Path]) -> Path | None:
"""Returns the first file associated to a feed for the current user."""
if files: # Check if any files were found
first_file = min(files, key=os.path.getctime)
else:
first_file = None # Or handle the case where no files are found appropriately. Perhaps raise an exception.
return first_file


def get_last_file(files: list[Path]) -> Path | None:
"""Returns the last file associated to a feed for the current user."""
if files: # Check if any files were found
latest_file = max(files, key=os.path.getctime)
else:
latest_file = None # Or handle the case where no files are found appropriately. Perhaps raise an exception.
return latest_file
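
For context, a minimal sketch of the surviving get_all_data helper, which concatenates feed files and de-duplicates on title (paths hypothetical):

from pathlib import Path

from bertrend_apps.prospective_demo.feeds_data import get_all_data

files = [
    Path("feeds/my_feed/2025-01.jsonl"),  # hypothetical paths
    Path("feeds/my_feed/2025-02.jsonl"),
]
df = get_all_data(files=files)  # one DataFrame, duplicate titles dropped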
