Skip to content

Commit

Permalink
fix tests, clean code
Browse files Browse the repository at this point in the history
fix tests, clean up overrides
  • Loading branch information
varun646 committed Feb 25, 2025
1 parent 8152e29 commit b0c5dc1
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 51 deletions.
36 changes: 2 additions & 34 deletions src/PatientX/models/BERTopicModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,37 +72,6 @@ def _save_representative_docs(self, documents: pd.DataFrame):
)
self.representative_docs_ = repr_docs

@override
def _extract_topics(
    self,
    documents: pd.DataFrame,
    embeddings: np.ndarray | None = None,
    mappings=None,
    verbose: bool = False,
):
    """Extract topics from the clusters using a class-based TF-IDF.

    Arguments:
        documents: Dataframe with documents and their corresponding IDs
        embeddings: The document embeddings
        mappings: The mappings from topic to word
        verbose: Whether to log the process of extracting topics

    Returns:
        None; results are stored on the instance (``c_tf_idf_``,
        ``topic_representations_``, ``bertopic_only_results``)
    """
    if verbose:
        logger.info("Representation - Extracting topics from clusters using representation models.")
        # Preview of the incoming frame, demoted from stray print() calls so it
        # respects the verbose flag and goes through the logger like the rest.
        logger.debug("Documents DF passed in:\n%s", documents.head(10))

    # Concatenate all documents of a topic into one string per topic before
    # computing the class-based TF-IDF.
    documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
    self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
    self.topic_representations_ = self._extract_words_per_topic(words, documents)
    self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)
    # Snapshot the plain BERTopic aggregation before any representation-model
    # post-processing, for later retrieval via get_bertopic_only_results().
    self.bertopic_only_results = documents_per_topic.copy(deep=True)

    if verbose:
        logger.info("Representation - Completed \u2713")

@override
def getClusters(self, datapath):
if Path(datapath).is_file():
Expand Down Expand Up @@ -134,9 +103,8 @@ def visualizeModel(self):
# visualize term rank
super().visualize_term_rank()

def get_bertopic_only_results(self) -> tuple[
Any, Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]:
return (self.bertopic_only_results, self.representative_docs_, self.bertopic_representative_words)
def get_bertopic_only_results(self) -> tuple[Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]:
    """Return the base-BERTopic outputs computed during fitting.

    Returns:
        A 2-tuple of (``representative_docs_``, ``bertopic_representative_words``):
        the representative documents stored by ``_save_representative_docs`` and
        the per-topic representative (word, score) data.
    """
    return self.representative_docs_, self.bertopic_representative_words

@override
def _extract_words_per_topic(
Expand Down
31 changes: 16 additions & 15 deletions src/PatientX/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_clustering_model(clustering_model: ClusteringModel) -> Optional[ClusterM
def run_bertopic_model(documents: List[str], embeddingspath: Path, dimensionality_reduction: DimensionalityReduction,
clustering_model: ClusteringModel, representationmodel: RepresentationModel, min_topic_size: int,
nr_docs: int, document_diversity: float, low_memory: bool) -> tuple[
DataFrame, ndarray | Any, tuple[Any, Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]]:
DataFrame, ndarray | Any, tuple[Any, dict[int, list[tuple[str | list[str], Any] | tuple[str, float]]]]]:
"""
Run the bertopic model on the given documents with the given model parameters
Expand Down Expand Up @@ -178,11 +178,23 @@ def run_bertopic_model(documents: List[str], embeddingspath: Path, dimensionalit

results_df = pd.concat([results_df, rep_docs_df], axis=1)

return results_df, document_embeddings, bertopic_model.get_bertopic_only_results()

def format_bertopic_results(results_df, representative_docs, bertopic_representative_words):
    """Assemble the final BERTopic results table for CSV export.

    Arguments:
        results_df: DataFrame with at least 'Topic' and 'Count' columns
        representative_docs: mapping of topic id -> representative documents
        bertopic_representative_words: mapping of topic id -> list of
            (word, score) tuples for that topic

    Returns:
        DataFrame concatenating, column-wise: topic ids, counts, the
        representative words (rendered as one string per topic, scores
        dropped), and the representative documents
    """
    # Render each topic's (word, score) list as a string of the words only.
    # (Dict comprehension replaces a lambda assigned to a name, PEP 8 E731.)
    words_by_topic = {
        topic: str([pair[0] for pair in pairs])
        for topic, pairs in bertopic_representative_words.items()
    }

    topics = results_df['Topic']
    counts = results_df['Count']

    rep_docs_df = pd.DataFrame.from_dict(representative_docs, orient='index')
    rep_words_df = pd.DataFrame.from_dict(words_by_topic, orient='index')

    # reset_index() puts every piece on a shared RangeIndex so the column-wise
    # concat aligns rows; the original topic ids survive as 'index' columns.
    return pd.concat(
        [topics.reset_index(), counts.reset_index(), rep_words_df.reset_index(), rep_docs_df.reset_index()],
        axis=1,
    )

@app.command()
@use_yaml_config()
Expand Down Expand Up @@ -237,20 +249,9 @@ def main(
nr_docs=nr_docs, document_diversity=document_diversity)
results_df.to_csv(resultpath / "output.csv", index=False)

bertopic_df, representative_docs, bertopic_representative_words = bertopic_only_results

get_words = lambda xs : str([x[0] for x in xs])
bertopic_representative_words = {k: get_words(v) for k, v in bertopic_representative_words.items()}

counts = results_df['Count']
topics = results_df['Topic']


bertopic_results_df = pd.DataFrame.from_dict(representative_docs, orient='index')
rep_words = pd.DataFrame.from_dict(bertopic_representative_words, orient='index')

representative_docs, bertopic_representative_words = bertopic_only_results

bertopic_final_res = pd.concat([topics.reset_index(), counts.reset_index(), rep_words.reset_index(), bertopic_results_df.reset_index()], axis=1)
bertopic_final_res = format_bertopic_results(results_df, representative_docs, bertopic_representative_words)
bertopic_final_res.to_csv(resultpath / "bertopic_final_results.csv", index=False)


Expand Down
8 changes: 6 additions & 2 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
@pytest.mark.parametrize("document_diversity", document_diversity_values)
def test_run_to_completion(fs, dimensionality_reduction_model, clustering_model, save_embeddings, document_diversity):
with patch("PatientX.run.get_representation_model", return_value=None) as mock_representation_model, \
patch("PatientX.run.run_bertopic_model", return_value=(result_df, [1,2,3])) as mock_bertopic:
patch("PatientX.run.run_bertopic_model", return_value=(pd.DataFrame(), pd.DataFrame(), (pd.DataFrame(), pd.DataFrame()))) as mock_bertopic, \
patch("PatientX.run.format_bertopic_results", return_value=pd.DataFrame()) as mock_bertopic_output:
repo_root = Path(__file__).parent.parent
output_dir = Path("test_output")
fs.create_dir(output_dir)
Expand All @@ -49,15 +50,18 @@ def test_run_to_completion(fs, dimensionality_reduction_model, clustering_model,
output_file = output_dir / "output.csv"
embeddings_file = output_dir / "embeddings.pkl"

bertopic_output_file = output_dir / "bertopic_final_results.csv"

if save_embeddings:
result = CliRunner().invoke(app, ["--datapath", input_dir, "--resultpath", output_dir, "--min-topic-size", 10, "--document-diversity", document_diversity, "--save-embeddings"])
else:
result = CliRunner().invoke(app,
["--datapath", input_dir, "--resultpath", output_dir, "--min-topic-size", 10,
"--document-diversity", document_diversity, "--no-save-embeddings"])

assert embeddings_file.exists() == save_embeddings
assert result.exit_code == 0
assert embeddings_file.exists() == save_embeddings
assert bertopic_output_file.exists()
assert output_file.exists()

def test_read_csv_files_in_directory(fs):
Expand Down

0 comments on commit b0c5dc1

Please sign in to comment.