-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from gt-sse-center/583-varun--add-bertopic-support
[583] varun add bertopic support
- Loading branch information
Showing
4 changed files
with
392 additions
and
87 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,355 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# BERTopic\n", | ||
"Using BERTopic to identify topics in dementia forum text. Each iteration adds a new level to the model. BERTopic has multiple fully customizable steps to it and each iteration explores with different parts of the model's pipeline\n", | ||
"\n", | ||
"" | ||
], | ||
"id": "6acf4b753ba73fa3" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Data Setup\n", | ||
"Read data into a list where each document is an item in the list" | ||
], | ||
"id": "a73f17e585d36ed5" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Read documents from the file\n", | ||
"# corpus_threads_combined.txt contains all dementia forum data\n", | ||
"# Each thread in the forum is represented as a document and separated by a new line\n", | ||
"\n", | ||
"with open('../data/corpus_threads_combined.txt', 'r', encoding='utf-8') as file:\n", | ||
" documents = file.read().split('\\n') # Split on newline to get individual documents" | ||
], | ||
"id": "4edbbb3f496cb603" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# install bertopic\n", | ||
"!pip install bertopic" | ||
], | ||
"id": "7f5fbd85cce2e908" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Approach 1: \n", | ||
"- **Embedding Model:** [all-MiniLM-L6-v2 Sentence Transformer](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)\n", | ||
"- **Dimensionality Reduction:** UMAP\n", | ||
"- **Clustering:** HDBScan\n", | ||
"- **Tokenizer:** *None*\n", | ||
"- **Weighting Scheme:** *None*\n", | ||
"- **Representation Tuning:** *None*" | ||
], | ||
"id": "161c6d02f49cd27f" | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "initial_id", | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from bertopic import BERTopic\n", | ||
"from sentence_transformers import SentenceTransformer\n", | ||
"\n", | ||
"# Initialize a sentence transformer model for embeddings\n", | ||
"embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n", | ||
"\n", | ||
"# Create a BERTopic model\n", | ||
"topic_model = BERTopic(embedding_model=embedding_model, verbose=True)\n", | ||
"\n", | ||
"# Fit the model on the documents\n", | ||
"topics, probs = topic_model.fit_transform(documents)" | ||
] | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Show results and inter-topic distance map visualization\n", | ||
"print(topic_model.get_topic_info())\n", | ||
"topic_model.visualize_topics()" | ||
], | ||
"id": "9de98bfff7d5a1d9" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Approach 2: additional stop word removal\n", | ||
"- **Embedding Model:** [all-MiniLM-L6-v2 Sentence Transformer](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)\n", | ||
"- **Dimensionality Reduction:** UMAP\n", | ||
"- **Clustering:** HDBScan\n", | ||
"- **Tokenizer:** CountVectorizer\n", | ||
"- **Weighting Scheme:** *None*\n", | ||
"- **Representation Tuning:** *None*\n", | ||
"### Clean up data\n", | ||
"Remove some custom stop words not in the existing spacy model's English stop words " | ||
], | ||
"id": "713008a010b293ae" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# remove custom stop words that aren't caught by spacy's model\n", | ||
"from spacy.lang.en import stop_words\n", | ||
"\n", | ||
"stop_words = list(stop_words.STOP_WORDS)\n", | ||
"custom_stop_words = ['with', 'my', 'your', 'she', 'this', 'was', 'her', 'have', 'as', 'he', 'him', 'but', 'not', 'so', 'are', 'at', 'be', 'has', 'do', 'got', 'how', 'on', 'or', 'would', 'will', 'what', 'they', 'if', 'or', 'get', 'can', 'we', 'me', 'can', 'has', 'his', 'there', 'them', 'just', 'am', 'by', 'that', 'from', 'it', 'is', 'in', 'you', 'also', 'very', 'had', 'a', 'an', 'for']\n", | ||
"\n", | ||
"stop_words += custom_stop_words" | ||
], | ||
"id": "11755848b38c073c" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"from sklearn.feature_extraction.text import CountVectorizer\n", | ||
"\n", | ||
"vectorizer_model = CountVectorizer(stop_words=custom_stop_words)\n", | ||
"topic_model_2 = BERTopic(vectorizer_model=vectorizer_model, embedding_model=embedding_model, verbose=True)" | ||
], | ||
"id": "e05ec5d2ac8741c4" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Fit the BERTopic model to the documents\n", | ||
"topics_2, probs_2 = topic_model_2.fit_transform(documents)" | ||
], | ||
"id": "ae9a1c8b5e0c2487" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Print the topic information\n", | ||
"print(topic_model_2.get_topic_info())\n", | ||
"\n", | ||
"# visualize inter-topic distance map\n", | ||
"topic_model_2.visualize_topics()" | ||
], | ||
"id": "f1eadc8ecb660be3" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Approach 3: c-TF-IDF weighting scheme\n", | ||
"- **Embedding Model:** [all-MiniLM-L6-v2 Sentence Transformer](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2)\n", | ||
"- **Dimensionality Reduction:** UMAP\n", | ||
"- **Clustering:** HDBScan\n", | ||
"- **Tokenizer:** CountVectorizer\n", | ||
"- **Weighting Scheme:** c-TF-IDF Transformer\n", | ||
"- **Representation Tuning:** *none*" | ||
], | ||
"id": "a38c6ecc23ab467a" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"from bertopic.vectorizers import ClassTfidfTransformer\n", | ||
"\n", | ||
"ctfidf_model = ClassTfidfTransformer(reduce_frequent_words=True)\n", | ||
"topic_model_3 = BERTopic(ctfidf_model=ctfidf_model, embedding_model=embedding_model, verbose=True, min_topic_size=100, vectorizer_model=vectorizer_model)" | ||
], | ||
"id": "9c13ad8cf1651e57" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Fit the BERTopic model to the documents\n", | ||
"topics_3, probs_3 = topic_model_3.fit_transform(documents)" | ||
], | ||
"id": "22fca7045b1ef800" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Print the topic information\n", | ||
"print(topic_model_3.get_topic_info())\n", | ||
"\n", | ||
"# visualize inter-topic distance map\n", | ||
"topic_model_3.visualize_topics()" | ||
], | ||
"id": "a858cc4f383f5cb7" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Approach 4: updated embedding model\n", | ||
"- **Embedding Model:** [pritamdeka/S-PubMedBert-MS-MARCO](https://huggingface.co/pritamdeka/S-PubMedBert-MS-MARCO)\n", | ||
"- **Dimensionality Reduction:** UMAP\n", | ||
"- **Clustering:** HDBScan\n", | ||
"- **Tokenizer:** CountVectorizer\n", | ||
"- **Weighting Scheme:** c-TF-IDF\n", | ||
"- **Representation Tuning:** *none*" | ||
], | ||
"id": "95406de44b3f9f21" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Initialize BERTopic with a sentence transformer fine-tuned on medical text for embeddings\n", | ||
"medical_embedding_model = SentenceTransformer('pritamdeka/S-PubMedBert-MS-MARCO')\n", | ||
"\n", | ||
"# Note: we tried using the embedding model below but came out with far worse results than the embedding model above\n", | ||
"# nvidia_embedding_model = SentenceTransformer('dunzhang/stella_en_1.5B_v5')\n", | ||
"\n", | ||
"topic_model_4 = BERTopic(ctfidf_model=ctfidf_model, embedding_model=medical_embedding_model, verbose=True, min_topic_size=100, vectorizer_model=vectorizer_model)\n", | ||
"\n", | ||
"\n", | ||
"# Fit the BERTopic model to the documents\n", | ||
"topics_4, probs_4 = topic_model_4.fit_transform(documents)" | ||
], | ||
"id": "9b67871ad3558689" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Print the topic information\n", | ||
"print(topic_model_4.get_topic_info())\n", | ||
"\n", | ||
"# visualize inter-topic distance map\n", | ||
"topic_model_4.visualize_topics()" | ||
], | ||
"id": "bf71e417e46be5b8" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "markdown", | ||
"source": [ | ||
"## Approach 5: adding KeyBERT representation model\n", | ||
"- **Embedding Model:** [pritamdeka/S-PubMedBert-MS-MARCO](https://huggingface.co/pritamdeka/S-PubMedBert-MS-MARCO)\n", | ||
"- **Dimensionality Reduction:** UMAP\n", | ||
"- **Clustering:** HDBScan\n", | ||
"- **Tokenizer:** CountVectorizer\n", | ||
"- **Weighting Scheme:** c-TF-IDF\n", | ||
"- **Representation Tuning:** KeyBERT" | ||
], | ||
"id": "9899025fb0a6e765" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"from bertopic.representation import KeyBERTInspired\n", | ||
"\n", | ||
"# Create your representation model\n", | ||
"representation_model = KeyBERTInspired()\n", | ||
"\n", | ||
"topic_model_5 = BERTopic(ctfidf_model=ctfidf_model, embedding_model=medical_embedding_model, verbose=True, min_topic_size=100, vectorizer_model=vectorizer_model, representation_model=representation_model)\n", | ||
"\n", | ||
"\n", | ||
"# Fit the BERTopic model to the documents\n", | ||
"topics_5, probs_5 = topic_model_5.fit_transform(documents)" | ||
], | ||
"id": "195cac749a7374b9" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Print the topic information\n", | ||
"print(topic_model_5.get_topic_info())\n", | ||
"\n", | ||
"# visualize inter-topic distance map\n", | ||
"topic_model_5.visualize_topics()" | ||
], | ||
"id": "4b6e4aa4ce896057" | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"outputs": [], | ||
"execution_count": null, | ||
"source": [ | ||
"# Initialize BERTopic with a sentence transformer for embeddings\n", | ||
"nvidia_embedding_model = SentenceTransformer('dunzhang/stella_en_1.5B_v5')\n", | ||
"topic_model_nvidia = BERTopic(ctfidf_model=ctfidf_model, embedding_model=nvidia_embedding_model, verbose=True, min_topic_size=100, vectorizer_model=vectorizer_model)\n", | ||
"\n", | ||
"\n", | ||
"# Fit the BERTopic model to the documents\n", | ||
"topics_nvidia, probs_nvidia = topic_model_nvidia.fit_transform(documents)\n", | ||
"\n", | ||
"# Print the topic information\n", | ||
"print(topic_model_nvidia.get_topic_info())" | ||
], | ||
"id": "24c7360e1dc51e90" | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.