From b0c5dc163488428c6a92ee50c9cc5344bb1360cc Mon Sep 17 00:00:00 2001 From: varun646 Date: Tue, 25 Feb 2025 15:13:36 -0500 Subject: [PATCH] fix tests, clean code fix tests, clean up overrides --- src/PatientX/models/BERTopicModel.py | 36 ++-------------------------- src/PatientX/run.py | 31 ++++++++++++------------ tests/test_run.py | 8 +++++-- 3 files changed, 24 insertions(+), 51 deletions(-) diff --git a/src/PatientX/models/BERTopicModel.py b/src/PatientX/models/BERTopicModel.py index 083da06..5330d99 100644 --- a/src/PatientX/models/BERTopicModel.py +++ b/src/PatientX/models/BERTopicModel.py @@ -72,37 +72,6 @@ def _save_representative_docs(self, documents: pd.DataFrame): ) self.representative_docs_ = repr_docs - @override - def _extract_topics( - self, - documents: pd.DataFrame, - embeddings: np.ndarray = None, - mappings=None, - verbose: bool = False, - ): - """Extract topics from the clusters using a class-based TF-IDF. - - Arguments: - documents: Dataframe with documents and their corresponding IDs - embeddings: The document embeddings - mappings: The mappings from topic to word - verbose: Whether to log the process of extracting topics - - Returns: - c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic - """ - if verbose: - logger.info("Representation - Extracting topics from clusters using representation models.") - print("Documents DF passed in: ") - print(documents.head(10)) - documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join}) - self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic) - self.topic_representations_ = self._extract_words_per_topic(words, documents) - self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings) - self.bertopic_only_results = documents_per_topic.copy(deep=True) - if verbose: - logger.info("Representation - Completed \u2713") - @override def getClusters(self, datapath): if Path(datapath).is_file(): @@ -134,9 +103,8 @@ def visualizeModel(self): # visualize term rank super().visualize_term_rank() - def get_bertopic_only_results(self) -> tuple[ - Any, Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]: - return (self.bertopic_only_results, self.representative_docs_, self.bertopic_representative_words) + def get_bertopic_only_results(self) -> tuple[Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]: + return self.representative_docs_, self.bertopic_representative_words @override def _extract_words_per_topic( diff --git a/src/PatientX/run.py b/src/PatientX/run.py index 32187b2..3182c0b 100644 --- a/src/PatientX/run.py +++ b/src/PatientX/run.py @@ -113,7 +113,7 @@ def get_clustering_model(clustering_model: ClusteringModel) -> Optional[ClusterM def run_bertopic_model(documents: List[str], embeddingspath: Path, dimensionality_reduction: DimensionalityReduction, clustering_model: ClusteringModel, representationmodel: RepresentationModel, min_topic_size: int, nr_docs: int, document_diversity: float, low_memory: bool) -> tuple[ - DataFrame, ndarray | Any, tuple[Any, Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]]: + DataFrame, ndarray | Any, tuple[Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]]: """ Run the bertopic model on the given documents with the given model parameters @@ -178,11 +178,23 @@ def run_bertopic_model(documents: List[str], embeddingspath: Path, dimensionalit results_df = pd.concat([results_df, rep_docs_df], axis=1) + return results_df, document_embeddings, bertopic_model.get_bertopic_only_results() +def format_bertopic_results(results_df, representative_docs, bertopic_representative_words): + get_words = lambda xs: str([x[0] for x in xs]) + bertopic_representative_words = {k: get_words(v) for k, v in bertopic_representative_words.items()} + counts = results_df['Count'] + topics = results_df['Topic'] - return results_df, document_embeddings, bertopic_model.get_bertopic_only_results() + bertopic_results_df = pd.DataFrame.from_dict(representative_docs, orient='index') + rep_words = pd.DataFrame.from_dict(bertopic_representative_words, orient='index') + bertopic_final_res = pd.concat( + [topics.reset_index(), counts.reset_index(), rep_words.reset_index(), bertopic_results_df.reset_index()], + axis=1) + + return bertopic_final_res @app.command() @use_yaml_config() @@ -237,20 +249,9 @@ def main( nr_docs=nr_docs, document_diversity=document_diversity) results_df.to_csv(resultpath / "output.csv", index=False) - bertopic_df, representative_docs, bertopic_representative_words = bertopic_only_results - - get_words = lambda xs : str([x[0] for x in xs]) - bertopic_representative_words = {k: get_words(v) for k, v in bertopic_representative_words.items()} - - counts = results_df['Count'] - topics = results_df['Topic'] - - - bertopic_results_df = pd.DataFrame.from_dict(representative_docs, orient='index') - rep_words = pd.DataFrame.from_dict(bertopic_representative_words, orient='index') - + representative_docs, bertopic_representative_words = bertopic_only_results - bertopic_final_res = pd.concat([topics.reset_index(), counts.reset_index(), rep_words.reset_index(), bertopic_results_df.reset_index()], axis=1) + bertopic_final_res = format_bertopic_results(results_df, representative_docs, bertopic_representative_words) bertopic_final_res.to_csv(resultpath / "bertopic_final_results.csv", index=False) diff --git a/tests/test_run.py b/tests/test_run.py index 5e220cc..82e3aa1 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -38,7 +38,8 @@ @pytest.mark.parametrize("document_diversity", document_diversity_values) def test_run_to_completion(fs, dimensionality_reduction_model, clustering_model, save_embeddings, document_diversity): with patch("PatientX.run.get_representation_model", return_value=None) as mock_representation_model, \ - patch("PatientX.run.run_bertopic_model", return_value=(result_df, [1,2,3])) as mock_bertopic: + patch("PatientX.run.run_bertopic_model", return_value=(pd.DataFrame(), pd.DataFrame(), (pd.DataFrame(), pd.DataFrame()))) as mock_bertopic, \ + patch("PatientX.run.format_bertopic_results", return_value=pd.DataFrame()) as mock_bertopic_output: repo_root = Path(__file__).parent.parent output_dir = Path("test_output") fs.create_dir(output_dir) @@ -49,6 +50,8 @@ def test_run_to_completion(fs, dimensionality_reduction_model, clustering_model, output_file = output_dir / "output.csv" embeddings_file = output_dir / "embeddings.pkl" + bertopic_output_file = output_dir / "bertopic_final_results.csv" + if save_embeddings: result = CliRunner().invoke(app, ["--datapath", input_dir, "--resultpath", output_dir, "--min-topic-size", 10, "--document-diversity", document_diversity, "--save-embeddings"]) else: @@ -56,8 +59,9 @@ def test_run_to_completion(fs, dimensionality_reduction_model, clustering_model, ["--datapath", input_dir, "--resultpath", output_dir, "--min-topic-size", 10, "--document-diversity", document_diversity, "--no-save-embeddings"]) - assert embeddings_file.exists() == save_embeddings assert result.exit_code == 0 + assert embeddings_file.exists() == save_embeddings + assert bertopic_output_file.exists() assert output_file.exists() def test_read_csv_files_in_directory(fs):