diff --git a/bertrend/demos/demos_utils/icons.py b/bertrend/demos/demos_utils/icons.py
index 3a19b06..e55e23c 100644
--- a/bertrend/demos/demos_utils/icons.py
+++ b/bertrend/demos/demos_utils/icons.py
@@ -6,6 +6,9 @@
 WARNING_ICON = ":material/warning:"
 ERROR_ICON = ":material/error:"
 INFO_ICON = ":material/info:"
+EDIT_ICON = ":material/edit:"
+ADD_ICON = ":material/add_circle:"
+DELETE_ICON = ":material/delete:"
 SUCCESS_ICON = ":material/check:"
 SETTINGS_ICON = ":material/settings:"
 TOPIC_ICON = ":material/speaker_notes:"
@@ -20,6 +23,9 @@
 MODEL_TRAINING_ICON = ":material/cognition:"
 SERVER_STORAGE_ICON = ":material/database:"
 CLIENT_STORAGE_ICON = ":material/upload:"
+UNHAPPY_ICON = ":material/sentiment_extremely_dissatisfied:"
+TOGGLE_ON_ICON = ":material/toggle_on:"
+TOGGLE_OFF_ICON = ":material/toggle_off:"
 JSON_ICON = "🧾"
 PARQUET_ICON = "📊"
diff --git a/bertrend_apps/common/crontab_utils.py b/bertrend_apps/common/crontab_utils.py
index 11f68bb..6c3a047 100644
--- a/bertrend_apps/common/crontab_utils.py
+++ b/bertrend_apps/common/crontab_utils.py
@@ -3,27 +3,76 @@
 # SPDX-License-Identifier: MPL-2.0
 # This file is part of BERTrend.
 import os
+import re
 import subprocess
 import sys
 from pathlib import Path
 
+from cron_descriptor import (
+    Options,
+    CasingTypeEnum,
+    ExpressionDescriptor,
+    DescriptionTypeEnum,
+)
 from loguru import logger
 
 from bertrend import BEST_CUDA_DEVICE, BERTREND_LOG_PATH, load_toml_config
 
 
-def add_job_to_crontab(schedule, command, env_vars=""):
-    logger.info(f"Adding to crontab: {schedule} {command}")
+def get_understandable_cron_description(cron_expression: str) -> str:
+    """Returns a human-readable (French) description of a crontab expression."""
+    options = Options()
+    options.casing_type = CasingTypeEnum.Sentence
+    options.use_24hour_time_format = True
+    options.locale_code = "fr_FR"
+    descriptor = ExpressionDescriptor(cron_expression, options)
+    return descriptor.get_description(DescriptionTypeEnum.FULL)
+
+
+def add_job_to_crontab(schedule, command, env_vars="") -> bool:
+    """Add the specified job to the crontab."""
+    logger.debug(f"Adding to crontab: {schedule} {command}")
     home = os.getenv("HOME")
     # Create crontab, add command - NB: we use the .bashrc to source all environment variables that may be required by the command
     cmd = f'(crontab -l; echo "{schedule} umask 002; source {home}/.bashrc; {env_vars} {command}" ) | crontab -'
     returned_value = subprocess.call(cmd, shell=True)  # returns the exit code in unix
-    logger.info(f"Crontab updated with status {returned_value}")
+    return returned_value == 0
 
 
-def schedule_scrapping(
-    feed_cfg: Path,
-):
+def check_cron_job(pattern: str) -> bool:
+    """Check if a specific pattern (expressed as a regular expression) matches crontab entries."""
+    try:
+        # Run `crontab -l` and capture the output
+        result = subprocess.run(
+            ["crontab", "-l"], capture_output=True, text=True, check=True
+        )
+        # Search for the regex pattern in the crontab output
+        return bool(re.search(pattern, result.stdout))
+    except subprocess.CalledProcessError:
+        # If crontab fails (e.g., no crontab for the user), return False
+        return False
+
+
+def remove_from_crontab(pattern: str) -> bool:
+    """Removes from the crontab the job matching the provided pattern (expressed as a regular expression)."""
+    if not check_cron_job(pattern):
+        logger.warning("No job matching the provided pattern")
+        return False
+    try:
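+        # Filter the matching jobs out of the current crontab and pipe the
+        # remainder back into `crontab -`, which replaces the user's crontab.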
+        subprocess.check_output(
+            f"crontab -l | grep -vE '{pattern}' | crontab -",
+            shell=True,
+        )
+        return True
+    except subprocess.CalledProcessError:
+        return False
+
+
+def schedule_scrapping(feed_cfg: Path):
     """Schedule data scrapping on the basis of a feed configuration file"""
     data_feed_cfg = load_toml_config(feed_cfg)
     schedule = data_feed_cfg["data-feed"]["update_frequency"]
@@ -44,3 +93,21 @@ def schedule_newsletter(
     command = f"{sys.prefix}/bin/python -m bertrend_apps.newsletters newsletters {newsletter_cfg_path.resolve()} {data_feed_cfg_path.resolve()} > {BERTREND_LOG_PATH}/cron_newsletter_{id}.log 2>&1"
     env_vars = f"CUDA_VISIBLE_DEVICES={cuda_devices}"
     add_job_to_crontab(schedule, command, env_vars)
+
+
+def check_if_scrapping_active_for_user(feed_id: str, user: str = None) -> bool:
+    """Checks if a given scraping feed is active (i.e., registered in the crontab)."""
+    if user:
+        return check_cron_job(rf"scrape-feed.*/feeds/users/{user}/{feed_id}_feed.toml")
+    else:
+        return check_cron_job(rf"scrape-feed.*/feeds/{feed_id}_feed.toml")
+
+
+def remove_scrapping_for_user(feed_id: str, user: str = None):
+    """Removes from the crontab the job matching the provided feed_id."""
+    if user:
+        return remove_from_crontab(
+            rf"scrape-feed.*/feeds/users/{user}/{feed_id}_feed.toml"
+        )
+    else:
+        return remove_from_crontab(rf"scrape-feed.*/feeds/{feed_id}_feed.toml")
diff --git a/bertrend_apps/data_provider/__init__.py b/bertrend_apps/data_provider/__init__.py
index ae6e745..0c6306d 100644
--- a/bertrend_apps/data_provider/__init__.py
+++ b/bertrend_apps/data_provider/__init__.py
@@ -2,3 +2,8 @@
 # See AUTHORS.txt
 # SPDX-License-Identifier: MPL-2.0
 # This file is part of BERTrend.
+# Define a pattern for a basic URL validation
+URL_PATTERN = (
+    r"^(https?://)?([a-z0-9-]+\.)+[a-z]{2,6}(:\d+)?(/[\w.-]*)*$|"
+    r"^(https?://)?(localhost|(\d{1,3}\.){3}\d{1,3})(:\d+)?(/[\w.-]*)*$"
+)
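+# Illustrative matches: "https://example.com/feed.xml", "localhost:8501",
+# "192.168.1.1/path"; uppercase hosts and URLs with query strings are rejected.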
diff --git a/bertrend_apps/data_provider/__main__.py b/bertrend_apps/data_provider/__main__.py
index 953cfc7..95c032a 100644
--- a/bertrend_apps/data_provider/__main__.py
+++ b/bertrend_apps/data_provider/__main__.py
@@ -15,6 +15,7 @@
 from bertrend_apps.common.crontab_utils import schedule_scrapping
 from bertrend_apps.data_provider.arxiv_provider import ArxivProvider
 from bertrend_apps.data_provider.bing_news_provider import BingNewsProvider
+from bertrend_apps.data_provider.curebot_provider import CurebotProvider
 from bertrend_apps.data_provider.google_news_provider import GoogleNewsProvider
 from bertrend_apps.data_provider.newscatcher_provider import NewsCatcherProvider
 
@@ -23,6 +24,7 @@
 PROVIDERS = {
     "arxiv": ArxivProvider,
+    "curebot": CurebotProvider,
     "google": GoogleNewsProvider,
     "bing": BingNewsProvider,
     "newscatcher": NewsCatcherProvider,
@@ -46,7 +48,7 @@ def scrape(
     max_results: int = typer.Option(
         50, help="maximum number of results per request"
     ),
-    save_path: str = typer.Option(
+    save_path: Path = typer.Option(
         None, help="Path for writing results. File is in jsonl format."
     ),
     language: str = typer.Option(None, help="Language filter"),
@@ -65,7 +67,7 @@
         "to" date, formatted as YYYY-MM-DD
     max_results: int
         Maximum number of results per request
-    save_path: str
+    save_path: Path
         Path to the output file (jsonl format)
     language: str
         Language filter
@@ -90,7 +92,7 @@ def auto_scrape(
     provider: str = typer.Option(
         "google", help="source for news [google, bing, newscatcher]"
     ),
-    save_path: str = typer.Option(None, help="Path for writing results."),
+    save_path: Path = typer.Option(None, help="Path for writing results."),
     language: str = typer.Option(None, help="Language filter"),
 ):
     """Scrape data from Arxiv, Google, Bing news or NewsCatcher (multiple requests from a configuration file: each line of the file shall be compliant with the following format:
@@ -100,9 +102,11 @@
     ----------
     requests_file: str
         Text file containing the list of requests to be processed
+    max_results: int
+        Maximum number of results per request
     provider: str
         News data provider. Current authorized values [google, bing, newscatcher]
-    save_path: str
+    save_path: Path
         Path to the output file (jsonl format)
     language: str
         Language filter
@@ -178,7 +182,7 @@
 @app.command("scrape-feed")
 def scrape_from_feed(
-    feed_cfg: str = typer.Argument(help="Path of the data feed config file"),
+    feed_cfg: Path = typer.Argument(help="Path of the data feed config file"),
 ):
     """Scrape data from Arxiv, Google, Bing news or NewsCatcher on the basis of a feed configuration file"""
     data_feed_cfg = load_toml_config(feed_cfg)
@@ -200,7 +204,7 @@
     # Generate a query file
     with tempfile.NamedTemporaryFile() as query_file:
-        if provider == "arxiv":  # already returns batches
+        if provider in ("arxiv", "curebot"):  # these providers already return batches
             scrape(
                 keywords=keywords,
                 provider=provider,
diff --git a/bertrend_apps/data_provider/curebot_provider.py b/bertrend_apps/data_provider/curebot_provider.py
index 79aa053..698229b 100644
--- a/bertrend_apps/data_provider/curebot_provider.py
+++ b/bertrend_apps/data_provider/curebot_provider.py
@@ -2,11 +2,13 @@
 # See AUTHORS.txt
 # SPDX-License-Identifier: MPL-2.0
 # This file is part of BERTrend.
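+# CurebotProvider reads either a Curebot Excel export or an ATOM feed URL; the feed
+# URL may also be passed through the "query" field of a feed configuration file.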
+import re
 from pathlib import Path
 
 import pandas as pd
 from loguru import logger
 
+from bertrend_apps.data_provider import URL_PATTERN
 from bertrend_apps.data_provider.data_provider import DataProvider
 import feedparser
 
@@ -19,6 +21,8 @@
     def __init__(self, curebot_export_file: Path = None, feed_url: str = None):
         self.data_file = curebot_export_file
         if self.data_file:
             self.df_dict = pd.read_excel(self.data_file, sheet_name=None, dtype=str)
+        else:
+            self.df_dict = None
         self.feed_url = feed_url
 
     def get_articles(
@@ -30,16 +34,22 @@
         language: str = "fr",
     ) -> list[dict]:
         """Requests the news data provider, collects a set of URLs to be parsed, return results as json lines"""
+        if query and re.match(URL_PATTERN, query):
+            # if using a config file, the "query" field may contain the feed url
+            self.feed_url = query
         if self.feed_url:
             return self.parse_ATOM_feed()
-        entries = []
-        for k in self.df_dict.keys():
-            entries += self.df_dict[k].to_dict(orient="records")
-        results = [self._parse_entry(res) for res in entries]
-        return [
-            res for res in results if res is not None
-        ]  # sanity check to remove errors
+        if self.df_dict:
+            entries = []
+            for k in self.df_dict.keys():
+                entries += self.df_dict[k].to_dict(orient="records")
+            results = [self._parse_entry(res) for res in entries]
+            return [
+                res for res in results if res is not None
+            ]  # sanity check to remove errors
+
+        return []
 
     def parse_ATOM_feed(self) -> list[dict]:
         feed = feedparser.parse(self.feed_url)
diff --git a/bertrend_apps/data_provider/data_provider.py b/bertrend_apps/data_provider/data_provider.py
index 29cb3ab..9607c8a 100644
--- a/bertrend_apps/data_provider/data_provider.py
+++ b/bertrend_apps/data_provider/data_provider.py
@@ -107,7 +107,8 @@ def store_articles(self, data: list[dict], file_path: Path):
         if not data:
             logger.error("No data to be stored!")
             return -1
-        with jsonlines.open(file_path, "w") as writer:
+        with jsonlines.open(file_path, "a") as writer:
+            # append to existing file
             writer.write_all(data)
         logger.info(f"Data stored to {file_path} [{len(data)} entries].")
diff --git a/bertrend_apps/prospective_demo/__init__.py b/bertrend_apps/prospective_demo/__init__.py
new file mode 100644
index 0000000..ae6e745
--- /dev/null
+++ b/bertrend_apps/prospective_demo/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
diff --git a/bertrend_apps/prospective_demo/app.py b/bertrend_apps/prospective_demo/app.py
new file mode 100644
index 0000000..b66d888
--- /dev/null
+++ b/bertrend_apps/prospective_demo/app.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
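+# Streamlit entry point of the demo: optional password authentication, a tab to
+# configure feeds and check collected data, and a tab for analysis dashboards.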
+from typing import Literal
+
+import streamlit as st
+
+from bertrend.demos.demos_utils import is_admin_mode
+from bertrend.demos.demos_utils.icons import (
+    SETTINGS_ICON,
+    ANALYSIS_ICON,
+    NEWSLETTER_ICON,
+    SERVER_STORAGE_ICON,
+)
+from bertrend.demos.demos_utils.state_utils import SessionStateManager
+from bertrend_apps.prospective_demo.authentication import check_password
+from bertrend_apps.prospective_demo.dashboard_analysis import dashboard_analysis
+from bertrend_apps.prospective_demo.feeds_config import configure_information_sources
+from bertrend_apps.prospective_demo.feeds_data import display_data_status
+
+# UI Settings
+PAGE_TITLE = "BERTrend - Prospective Analysis demo"
+LAYOUT: Literal["centered", "wide"] = "wide"
+
+# TODO: reactivate password
+# AUTHENTICATION = True
+AUTHENTICATION = False
+
+
+def main():
+    """Main page"""
+    st.set_page_config(
+        page_title=PAGE_TITLE,
+        layout=LAYOUT,
+        initial_sidebar_state="expanded" if is_admin_mode() else "collapsed",
+        page_icon=":part_alternation_mark:",
+    )
+
+    st.title(":part_alternation_mark: " + PAGE_TITLE)
+
+    if AUTHENTICATION:
+        username = check_password()
+        if not username:
+            st.stop()
+        else:
+            SessionStateManager.set("username", username)
+    else:
+        # fallback username when authentication is deactivated
+        SessionStateManager.get_or_set("username", "nemo")
+
+    # Sidebar
+    with st.sidebar:
+        st.header(SETTINGS_ICON + " Settings and Controls")
+
+    # Main content
+    tab1, tab2 = st.tabs(
+        [
+            NEWSLETTER_ICON + " Mes veilles",
+            ANALYSIS_ICON + " Mes analyses",
+        ]
+    )
+
+    with tab1:
+        with st.expander(
+            "Configuration des flux de données", expanded=True, icon=SETTINGS_ICON
+        ):
+            configure_information_sources()
+
+        with st.expander(
+            "État de collecte des données", expanded=False, icon=SERVER_STORAGE_ICON
+        ):
+            display_data_status()
+
+    with tab2:
+        dashboard_analysis()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/bertrend_apps/prospective_demo/authentication.py b/bertrend_apps/prospective_demo/authentication.py
new file mode 100644
index 0000000..94e3967
--- /dev/null
+++ b/bertrend_apps/prospective_demo/authentication.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
+import hmac
+
+import streamlit as st
+
+from bertrend.demos.demos_utils.icons import UNHAPPY_ICON
+
+
+def login_form():
+    """Form with widgets to collect user information"""
+    with st.form("Credentials"):
+        st.text_input("Username", key="username")
+        st.text_input("Password", type="password", key="password")
+        st.form_submit_button("Log in", on_click=password_entered)
+
+
+def password_entered():
+    """Checks whether the password entered by the user is correct."""
+    if st.session_state["username"] in st.secrets["passwords"] and hmac.compare_digest(
+        st.session_state["password"],
+        st.secrets.passwords[st.session_state["username"]],
+    ):
+        st.session_state["password_correct"] = True
+        del st.session_state["password"]  # Don't keep the password in session state.
+        # The username is kept: check_password() returns it.
+    else:
+        st.session_state["password_correct"] = False
+
+
+def check_password() -> str | None:
+    """Returns the user name if the user had a correct password, otherwise None."""
+
+    # Return the username if the credentials were already validated.
+    if st.session_state.get("password_correct", False):
+        return st.session_state["username"]
+
+    # Show inputs for username + password.
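+    # password_entered() runs as the form callback and stores the result in
+    # st.session_state["password_correct"].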
+    login_form()
+    if "password_correct" in st.session_state:
+        st.error(f"{UNHAPPY_ICON} User not known or password incorrect")
+    return None
diff --git a/bertrend_apps/prospective_demo/dashboard_analysis.py b/bertrend_apps/prospective_demo/dashboard_analysis.py
new file mode 100644
index 0000000..a93b782
--- /dev/null
+++ b/bertrend_apps/prospective_demo/dashboard_analysis.py
@@ -0,0 +1,14 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
+import streamlit as st
+
+
+def dashboard_analysis():
+    """Dashboard to analyze information monitoring results"""
+
+    selected_id = st.selectbox(
+        "Sélection de la veille", options=sorted(st.session_state.user_feeds.keys())
+    )
+    st.write(selected_id)
diff --git a/bertrend_apps/prospective_demo/feeds_config.py b/bertrend_apps/prospective_demo/feeds_config.py
new file mode 100644
index 0000000..bdac38d
--- /dev/null
+++ b/bertrend_apps/prospective_demo/feeds_config.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
+import re
+import time
+from pathlib import Path
+
+import pandas as pd
+import streamlit as st
+import toml
+from loguru import logger
+
+from bertrend import FEED_BASE_PATH, load_toml_config
+from bertrend.config.parameters import LANGUAGES
+from bertrend.demos.demos_utils.icons import (
+    INFO_ICON,
+    ERROR_ICON,
+    ADD_ICON,
+    EDIT_ICON,
+    DELETE_ICON,
+    WARNING_ICON,
+    TOGGLE_ON_ICON,
+    TOGGLE_OFF_ICON,
+)
+from bertrend_apps.common.crontab_utils import (
+    get_understandable_cron_description,
+    check_if_scrapping_active_for_user,
+    remove_scrapping_for_user,
+    schedule_scrapping,
+)
+from bertrend_apps.data_provider import URL_PATTERN
+from bertrend_apps.prospective_demo.streamlit_utils import clickable_df
+
+# Default feed configs
+DEFAULT_GNEWS_CRONTAB_EXPRESSION = "1 0 * * 1"
+DEFAULT_CUREBOT_CRONTAB_EXPRESSION = "42 0,6,12,18 * * *"  # 4 times a day
+DEFAULT_MAX_RESULTS = 20
+DEFAULT_NUMBER_OF_DAYS = 14
+FEED_SOURCES = ["google", "curebot"]
+TRANSLATION = {"English": "Anglais", "French": "Français"}
+
+# Feed config path
+USER_FEEDS_BASE_PATH = FEED_BASE_PATH / "users"
+USER_FEEDS_BASE_PATH.mkdir(parents=True, exist_ok=True)
+
+
+@st.dialog("Configuration d'un nouveau flux de données")
+def edit_feed_monitoring(config: dict | None = None):
+    """Create or update a feed monitoring configuration."""
+    chosen_id = st.text_input(
+        "ID :red[*]",
+        help="Identifiant du flux de données",
+        value=None if not config else config["id"],
+    )
+
+    provider = st.segmented_control(
+        "Source",
+        selection_mode="single",
+        options=FEED_SOURCES,
+        default=FEED_SOURCES[0] if not config else config["provider"],
+        help="Sélection de la source de données",
+    )
+    if provider == "google":
+        query = st.text_input(
+            "Requête :red[*]",
+            value="" if not config else config["query"],
+            help="Saisir ici la requête qui sera faite sur Google News",
+        )
+        language = st.segmented_control(
+            "Langue",
+            selection_mode="single",
+            options=LANGUAGES,
+            default=LANGUAGES[0],
+            format_func=lambda lang: TRANSLATION[lang],
+            help="Choix de la langue",
+        )
+        if "update_frequency" not in st.session_state:
+            st.session_state.update_frequency = (
+                DEFAULT_GNEWS_CRONTAB_EXPRESSION
+                if not config
+                else config["update_frequency"]
+            )
+        new_freq = st.text_input(
+            "Fréquence d'exécution",
+            value=st.session_state.update_frequency,
+            help="Fréquence de collecte des données",
+        )
+        st.session_state.update_frequency = new_freq
+        st.write(display_crontab_description(st.session_state.update_frequency))
+
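+    # Curebot feeds are polled on a fixed schedule (DEFAULT_CUREBOT_CRONTAB_EXPRESSION),
+    # so only the ATOM feed URL needs to be entered.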
+    elif provider == "curebot":
+        query = st.text_input(
+            "ATOM feed :red[*]",
+            value="" if not config else config["query"],
+            help="URL du flux de données Curebot",
+        )
+
+    try:
+        get_understandable_cron_description(st.session_state.update_frequency)
+        valid_cron = True
+    except Exception:
+        valid_cron = False
+
+    if st.button(
+        "OK",
+        disabled=not chosen_id
+        or not query
+        or (query and provider == "curebot" and not re.match(URL_PATTERN, query)),
+    ):
+        if not config:
+            config = {}
+        config["id"] = "feed_" + chosen_id
+        config["feed_dir_path"] = st.session_state.username + "/feed_" + chosen_id
+        config["query"] = query
+        config["provider"] = provider
+        if not config.get("max_results"):
+            config["max_results"] = DEFAULT_MAX_RESULTS
+        if not config.get("number_of_days"):
+            config["number_of_days"] = DEFAULT_NUMBER_OF_DAYS
+        if provider == "google":
+            config["language"] = "fr" if language == "French" else "en"
+            config["update_frequency"] = (
+                st.session_state.update_frequency
+                if valid_cron
+                else DEFAULT_GNEWS_CRONTAB_EXPRESSION
+            )
+        elif provider == "curebot":
+            config["language"] = "fr"
+            config["update_frequency"] = DEFAULT_CUREBOT_CRONTAB_EXPRESSION
+
+        if "update_frequency" in st.session_state:
+            del st.session_state["update_frequency"]  # to avoid memory effect
+
+        # Remove previous crontab entry if any
+        remove_scrapping_for_user(feed_id=chosen_id, user=st.session_state.username)
+
+        # Save feed config and update crontab
+        save_feed_config(chosen_id, config)
+
+
+def save_feed_config(chosen_id, feed_config: dict):
+    """Save the feed configuration to disk as a TOML file."""
+    feed_path = (
+        USER_FEEDS_BASE_PATH / st.session_state.username / f"{chosen_id}_feed.toml"
+    )
+    # Save the dictionary to a TOML file
+    with open(feed_path, "w") as toml_file:
+        toml.dump({"data-feed": feed_config}, toml_file)
+    logger.debug(f"Saved feed config {feed_config} to {feed_path}")
+    schedule_scrapping(feed_path)
+    st.rerun()
+
+
+def display_crontab_description(crontab_expr: str) -> str:
+    try:
+        return f":blue[{INFO_ICON} {get_understandable_cron_description(crontab_expr)}]"
+    except Exception:
+        return f":red[{ERROR_ICON} Expression mal écrite !]"
+
+
+def read_user_feeds(username: str) -> tuple[dict[str, dict], dict[str, Path]]:
+    user_feed_dir = USER_FEEDS_BASE_PATH / username
+    user_feed_dir.mkdir(parents=True, exist_ok=True)
+    logger.debug(f"Reading user feeds from: {user_feed_dir}")
+    matching_files = user_feed_dir.rglob("*_feed.toml")
+
+    user_feeds = {}
+    feed_files = {}
+    for f in matching_files:
+        feed_id = f.name.split("_feed.toml")[0]
+        user_feeds[feed_id] = load_toml_config(f)
+        feed_files[feed_id] = f
+
+    return user_feeds, feed_files
+
+
+def configure_information_sources():
+    """Configure Information Sources."""
+    # if "user_feeds" not in st.session_state:
+    st.session_state.user_feeds, st.session_state.feed_files = read_user_feeds(
+        st.session_state.username
+    )
+
+    displayed_list = []
+    for k, v in st.session_state.user_feeds.items():
+        displayed_list.append(
+            {
+                "id": k,
+                "provider": v["data-feed"]["provider"],
+                "query": v["data-feed"]["query"],
+                "language": v["data-feed"]["language"],
+                "update_frequency": v["data-feed"]["update_frequency"],
+            }
+        )
+    df = pd.DataFrame(displayed_list)
+    if not df.empty:
+        df = df.sort_values(by="id", inplace=False).reset_index(drop=True)
+
+    if st.button(ADD_ICON, type="tertiary", help="Nouveau flux de veille"):
+        edit_feed_monitoring()
+
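+    # One action button per row: edit the feed, toggle its crontab entry, or delete it.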
+    clickable_df_buttons = [
+        (EDIT_ICON, edit_feed_monitoring, "secondary"),
+        (lambda x: toggle_icon(df, x), toggle_feed, "secondary"),
+        (DELETE_ICON, handle_delete, "primary"),
+    ]
+    clickable_df(df, clickable_df_buttons)
+
+
+def toggle_icon(df: pd.DataFrame, index: int) -> str:
+    """Switch the toggle icon depending on the status of the scraping feed in the crontab."""
+    feed_id = df["id"][index]
+    return (
+        f":green[{TOGGLE_ON_ICON}]"
+        if check_if_scrapping_active_for_user(
+            feed_id=feed_id, user=st.session_state.username
+        )
+        else f":red[{TOGGLE_OFF_ICON}]"
+    )
+
+
+def toggle_feed(cfg: dict):
+    """Activate / deactivate the feed from the crontab"""
+    feed_id = cfg["id"]
+    if check_if_scrapping_active_for_user(
+        feed_id=feed_id, user=st.session_state.username
+    ):
+        if remove_scrapping_for_user(feed_id=feed_id, user=st.session_state.username):
+            st.toast(f"Le flux **{feed_id}** est désactivé !", icon=INFO_ICON)
+            logger.info(f"Flux {feed_id} désactivé !")
+    else:
+        schedule_scrapping(st.session_state.feed_files[feed_id])
+        st.toast(f"Le flux **{feed_id}** est activé !", icon=WARNING_ICON)
+        logger.info(f"Flux {feed_id} activé !")
+    time.sleep(0.2)
+    st.rerun()
+
+
+def delete_feed_config(feed_id: str):
+    # remove config file
+    file_path: Path = st.session_state.feed_files[feed_id]
+    try:
+        file_path.unlink()
+        logger.debug(f"Feed file {file_path} has been removed.")
+    except Exception as e:
+        logger.error(f"An error occurred: {e}")
+
+
+@st.dialog("Confirmation")
+def handle_delete(row_dict: dict):
+    """Function to handle remove click events"""
+    feed_id = row_dict["id"]
+    st.write(
+        f":orange[{WARNING_ICON}] Voulez-vous vraiment supprimer le flux de veille **{feed_id}** ?"
+    )
+    col1, col2, _ = st.columns([2, 2, 8])
+    with col1:
+        if st.button("Oui", type="primary"):
+            remove_scrapping_for_user(feed_id=feed_id, user=st.session_state.username)
+            delete_feed_config(feed_id)
+            logger.info(f"Flux {feed_id} supprimé !")
+            time.sleep(0.2)
+            st.rerun()
+    with col2:
+        if st.button("Non"):
+            st.rerun()
diff --git a/bertrend_apps/prospective_demo/feeds_data.py b/bertrend_apps/prospective_demo/feeds_data.py
new file mode 100644
index 0000000..c195e7d
--- /dev/null
+++ b/bertrend_apps/prospective_demo/feeds_data.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
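+# Helpers to inspect the data collected for each feed: file inventory, date range,
+# and article counts over a configurable time window.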
+import os
+import datetime
+from pathlib import Path
+
+import jsonlines
+import pandas as pd
+import streamlit as st
+from loguru import logger
+
+from bertrend import FEED_BASE_PATH
+from bertrend.utils.data_loading import load_data, TIMESTAMP_COLUMN
+
+
+def display_data_status():
+    st.selectbox(
+        "Sélection de la veille",
+        options=sorted(st.session_state.user_feeds.keys()),
+        key="id_data",
+    )
+
+    if "time_window" not in st.session_state:
+        st.session_state.time_window = 7
+    st.slider(
+        "Fenêtre temporelle (jours)",
+        min_value=1,
+        max_value=60,
+        step=1,
+        key="time_window",
+    )
+
+    display_data_info_for_feed(st.session_state.id_data)
+
+
+def display_data_info_for_feed(feed_id: str):
+    all_files = get_all_files_for_feed(feed_id)
+    df = get_all_data(files=all_files)
+
+    if df.empty:
+        df_filtered = pd.DataFrame()
+    else:
+        cutoff_date = datetime.datetime.now() - datetime.timedelta(
+            days=st.session_state.time_window
+        )
+        df_filtered = df[df[TIMESTAMP_COLUMN] >= cutoff_date]
+
+    number_articles = count_lines_jsonl(all_files)
+
+    stats = {
+        "ID": feed_id,
+        "# Fichiers": len(all_files),
+        "Date début": df[TIMESTAMP_COLUMN].min() if not df.empty else None,
+        "Date fin": df[TIMESTAMP_COLUMN].max() if not df.empty else None,
+        "Nombre d'articles": number_articles,
+        f"Nombre d'articles (derniers {st.session_state.time_window} jours)": len(
+            df_filtered
+        ),
+    }
+
+    st.dataframe(pd.DataFrame([stats]))
+
+    st.write(f"#### Données des derniers {st.session_state.time_window} jours")
+    st.dataframe(df_filtered, use_container_width=True)
+
+
+def count_lines_jsonl(filepaths: list[Path]) -> int:
+    """Counts the total number of lines (JSON objects) across multiple JSONL files.
+
+    Args:
+        filepaths: A list of Path objects representing the paths to the JSONL files.
+
+    Returns:
+        The total number of lines across all files.
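+        Files that cannot be read are skipped, with a warning logged.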
+ """ + total_count = 0 + for filepath in filepaths: + try: + with jsonlines.open(filepath) as reader: + count = 0 + for _ in reader: + count += 1 + total_count += count + except FileNotFoundError: + logger.warning(f"Error: File not found: {filepath}") + except Exception as e: + logger.warning(f"An error occurred while processing {filepath}: {e}") + + return total_count + + +@st.cache_data +def get_all_data(files: list[Path]) -> pd.DataFrame: + """Returns the data contained in the provided files as a single DataFrame.""" + if not files: + return pd.DataFrame() + dfs = [load_data(Path(f)) for f in files] + new_df = pd.concat(dfs).drop_duplicates( + subset=["title"], keep="first", inplace=False + ) + return new_df + + +def get_all_files_for_feed(feed_id: str) -> list[Path]: + """Returns the paths of all files associated to a feed for the current user.""" + feed_base_dir = st.session_state.user_feeds[feed_id]["data-feed"]["feed_dir_path"] + list_all_files = list( + Path(FEED_BASE_PATH, feed_base_dir).glob( + f"*{st.session_state.user_feeds[feed_id]['data-feed'].get('id')}*.jsonl*" + ) + ) + return list_all_files + + +def get_last_files(files: list[Path], time_window: int) -> list[Path] | None: + """Returns the paths of all files associated to a feed for the current user in the last time window.""" + cutoff_date = datetime.datetime.now() - datetime.timedelta(days=time_window) + matching_files = [] + for file in files: + try: + file_stat = file.stat() # Get file stats only once + print(file_stat) + file_time = datetime.datetime.fromtimestamp(file_stat.st_mtime) + if file_time >= cutoff_date: + matching_files.append(file) + except OSError as e: + logger.warning(f"Error accessing file {file}: {e}") + # Handle the error as needed (e.g., skip the file) + + +def get_first_file(files: list[Path]) -> Path | None: + """Returns the first file associated to a feed for the current user.""" + if files: # Check if any files were found + first_file = min(files, key=os.path.getctime) + else: + first_file = None # Or handle the case where no files are found appropriately. Perhaps raise an exception. + return first_file + + +def get_last_file(files: list[Path]) -> Path | None: + """Returns the last file associated to a feed for the current user.""" + if files: # Check if any files were found + latest_file = max(files, key=os.path.getctime) + else: + latest_file = None # Or handle the case where no files are found appropriately. Perhaps raise an exception. + return latest_file + + # check feed config for path + + # select time window in days + # display + # latest data date + # last data date + # number of files + # number of articles + # list of files last x days + # number of articles last x days diff --git a/bertrend_apps/prospective_demo/start_demo.sh b/bertrend_apps/prospective_demo/start_demo.sh new file mode 100755 index 0000000..c150877 --- /dev/null +++ b/bertrend_apps/prospective_demo/start_demo.sh @@ -0,0 +1,9 @@ +# +# Copyright (c) 2024, RTE (https://www.rte-france.com) +# See AUTHORS.txt +# SPDX-License-Identifier: MPL-2.0 +# This file is part of BERTrend. 
+
+# Starts the Prospective Analysis demo application
+CUDA_VISIBLE_DEVICES=0 streamlit run --theme.primaryColor royalblue app.py
\ No newline at end of file
diff --git a/bertrend_apps/prospective_demo/streamlit_utils.py b/bertrend_apps/prospective_demo/streamlit_utils.py
new file mode 100644
index 0000000..a5f0268
--- /dev/null
+++ b/bertrend_apps/prospective_demo/streamlit_utils.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2024, RTE (https://www.rte-france.com)
+# See AUTHORS.txt
+# SPDX-License-Identifier: MPL-2.0
+# This file is part of BERTrend.
+from typing import Callable
+
+import pandas as pd
+import streamlit as st
+
+
+def clickable_df(
+    df: pd.DataFrame, clickable_buttons: list[tuple[str | Callable, Callable, str]]
+):
+    """Streamlit display of a df-like rendering with additional clickable columns (buttons)."""
+    # Header row: data columns get a relative width of 3, button columns a width of 1
+    cols = st.columns(len(df.columns) * [3] + len(clickable_buttons) * [1])
+    for i, c in enumerate(df.columns):
+        with cols[i]:
+            st.write(f"**{c}**")
+    for index, row in df.iterrows():
+        # Create a clickable container for each row
+        cols = st.columns(len(df.columns) * [3] + len(clickable_buttons) * [1])
+        for i, col in enumerate(cols[: -len(clickable_buttons)]):
+            with col:
+                st.write(row[df.columns[i]])
+        # Render the additional columns (clickable buttons); a label may be a fixed
+        # string or a callable computed from the row index
+        for i, button in enumerate(clickable_buttons):
+            with cols[len(df.columns) + i]:
+                button_label = button[0](index) if callable(button[0]) else button[0]
+                if st.button(button_label, key=f"button{i}_{index}", type=button[2]):
+                    button[1](df.iloc[index].to_dict())
diff --git a/pyproject.toml b/pyproject.toml
index 9b5aae9..d5d4c77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,6 +21,7 @@
 ipython = "^8.28.0"
 #accelerate = "^0.34.2" # required?
 bertopic = "0.16.2"
 black = "^24.10.0"
+cron-descriptor = "^1.4.5"
 datamapplot = "0.3.0"
 dateparser = "^1.2.0"
 dask = "2024.12.0" # issues with >=2025.x (https://github.com/dask/dask/issues/11678)