diff --git a/bertrend/BERTrend.py b/bertrend/BERTrend.py index 42d4679..ce2a4e5 100644 --- a/bertrend/BERTrend.py +++ b/bertrend/BERTrend.py @@ -649,11 +649,13 @@ def save_models(self, models_path: Path = MODELS_DIR): topic_model.topic_info_df.to_pickle(model_dir / TOPIC_INFO_DF_FILE) # Serialize BERTrend (excluding topic models for separate reuse if needed) - topic_models_bak = copy.deepcopy(self.topic_models) - self.topic_models = None + # topic_models_bak = copy.deepcopy(self.topic_models) + # FIXME: the commented code introduced a too-heavy memory overhead, to be improved; the idea is to serialize + # the topics models separetely from the rest of the BERTrend object + # self.topic_models = None with open(models_path / BERTREND_FILE, "wb") as f: dill.dump(self, f) - self.topic_models = topic_models_bak + # self.topic_models = topic_models_bak logger.info(f"Models saved to: {models_path}") diff --git a/bertrend_apps/prospective_demo/dashboard_common.py b/bertrend_apps/prospective_demo/dashboard_common.py index 5d67cc5..22bcac6 100644 --- a/bertrend_apps/prospective_demo/dashboard_common.py +++ b/bertrend_apps/prospective_demo/dashboard_common.py @@ -29,6 +29,14 @@ def get_df_topics(model_interpretation_path=None) -> dict[str, pd.DataFrame]: def update_key(key: str, new_value: Any): st.session_state[key] = new_value + # reset ts value in order to avoid errors if a previous ts value is not available for the new key value + + +def update_key_and_ts(key: str, new_value: Any): + update_key(key, new_value) + # reset ts value in order to avoid errors if a previous ts value is not available for the new key value + if "reference_ts" in st.session_state: + del st.session_state["reference_ts"] def choose_id_and_ts(): @@ -43,7 +51,9 @@ def choose_id_and_ts(): options=options, index=options.index(st.session_state.model_id), key=model_id_key, # to avoid pb of unicity if displayed on several places - on_change=lambda: update_key("model_id", st.session_state[model_id_key]), + on_change=lambda: update_key_and_ts( + "model_id", st.session_state[model_id_key] + ), ) with col2: list_models = get_models_info(model_id) @@ -58,7 +68,7 @@ def choose_id_and_ts(): if "reference_ts" not in st.session_state: st.session_state.reference_ts = list_models[-1] ts_key = uuid.uuid4() - reference_ts = st.select_slider( + st.select_slider( "Date d'analyse", options=list_models, value=st.session_state.reference_ts, diff --git a/bertrend_apps/prospective_demo/models_info.py b/bertrend_apps/prospective_demo/models_info.py index 9a98783..b3099b4 100644 --- a/bertrend_apps/prospective_demo/models_info.py +++ b/bertrend_apps/prospective_demo/models_info.py @@ -222,7 +222,7 @@ def handle_regenerate_models(row_dict: dict): with col1: if yes_btn := st.button("Oui", type="primary"): # Delete previously stored model - # delete_cached_models(model_id) + delete_cached_models(model_id) logger.info(f"Modèles en cache supprimés pour la veille {model_id} !") # Regenerate new models