Merge pull request #20 from gt-sse-center/607-Robin--enable-context-support

[607] Update notebook to install additional dependencies and update the prompt
varun646 authored Jan 17, 2025
2 parents 21e6ceb + 9e0d28d commit c7a7841
Showing 5 changed files with 313 additions and 65 deletions.
93 changes: 93 additions & 0 deletions python-notebooks/MistralNotebook_README.md
@@ -0,0 +1,93 @@
# Documentation for LLM Integration in PatientX.AI

This document explains how to utilize the LLM-related components in the **PatientX.AI** project, specifically focusing on:

1. **Using BERTopic with LLMs**
2. **Understanding the `MistralRepresentation` Class and Customizing Prompts**
3. **Interacting with the Chat/Generate APIs**

---

## 1. Using BERTopic with LLMs

The project integrates BERTopic to perform topic modeling. Here's a brief overview of the setup in the `bertopic.ipynb` notebook:

### Prerequisites

Ensure you have the required libraries installed:
```bash
# Install the following packages; depending on your system, you may use pip instead of pip3.
pip install --upgrade pip
pip3 install numpy==1.24.4
pip3 install bertopic
pip3 install spacy
pip3 install datamapplot
pip3 install "nbformat>=4.2.0"
pip3 install --upgrade nbformat
pip3 install ipykernel

```
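
With the dependencies installed, the notebook wires `MistralRepresentation` into BERTopic roughly as in the sketch below (the corpus and the constructor arguments shown here are assumptions; check `bertopic.ipynb` and `MistralRepresentation.py` for the exact values):

```python
from bertopic import BERTopic
from MistralRepresentation import MistralRepresentation

# Hypothetical corpus; the notebook loads the PatientX.AI forum data instead.
# BERTopic needs a reasonably large corpus for clustering to work well.
docs = ["first forum post ...", "second forum post ...", "..."]

# Use the custom LLM-based representation model (other constructor arguments,
# such as a model name or generator_kwargs, may be required).
representation_model = MistralRepresentation(api="generate")

topic_model = BERTopic(representation_model=representation_model, verbose=True)
topics, probs = topic_model.fit_transform(docs)

print(topic_model.get_topic_info())
```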
---

## 2. Advanced Usage of `MistralRepresentation`

This section provides deeper insights into using the `MistralRepresentation` class for advanced tasks.

### 1. Streaming Responses from APIs

The `stream_response` method allows you to handle responses incrementally when working with APIs that support streaming. This is particularly useful when generating lengthy responses.

#### Example:
```python
# `mistral_representation` is assumed to be an instance of MistralRepresentation.
url = "http://127.0.0.1:11434/api/generate"
payload = {
    "model": "mistral-small",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "prompt": "What is the capital of France?",
}
# Including both "messages" and "prompt" keeps the payload compatible with api/generate and api/chat.

response = mistral_representation.stream_response(url, payload)
print(response)
```
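
The same method works against the chat endpoint; only the URL and payload change (a sketch, assuming the same local Ollama-style server is running):

```python
# Chat endpoint variant: the URL ends in /api/chat and only "messages" is needed.
url = "http://127.0.0.1:11434/api/chat"
payload = {
    "model": "mistral-small",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}

response = mistral_representation.stream_response(url, payload)
print(response)
```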

### 2. Customizing Prompts

Prompts are central to how the LLM interprets the input.

#### Default Prompts

The default prompts are defined as constants in `MistralRepresentation.py`. You can modify them as needed for your specific use case:

• DEFAULT_PROMPT:
```python
Here are documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
I need you to write "The topic is:" then print a short description of the documents in markdown format.
```
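
You can also pass a custom prompt to the constructor; the `[DOCUMENTS]` and `[KEYWORDS]` tags are substituted per topic by `_create_prompt` (a sketch; other constructor arguments are omitted):

```python
# A sketch of a custom prompt; [DOCUMENTS] and [KEYWORDS] are replaced for each topic.
custom_prompt = """
Here are documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
Write "The topic is:" followed by a one-sentence summary of the documents.
"""

representation_model = MistralRepresentation(prompt=custom_prompt, api="generate")
```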

When the chat API is used, the default prompt is composed of the following elements, sent in order. A minimal sketch of the resulting message sequence follows the list.

• DEFAULT_PROMPT_CHAT_START: sent first to introduce the task.
• DEFAULT_PROMPT_CHAT_CONTEXT: sent once per representative document, with `[DOCUMENTS]` replaced by that document.
• DEFAULT_PROMPT_CHAT_ENDING: sent last to ask the model for the topic description.
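
Here is that sequence as a sketch (constant names come from `MistralRepresentation.py`, assumed importable from the notebook's working directory; the example documents and the plain `.replace()` substitution are simplifications of what `extract_topics` actually does):

```python
from MistralRepresentation import (
    DEFAULT_PROMPT_CHAT_START,
    DEFAULT_PROMPT_CHAT_CONTEXT,
    DEFAULT_PROMPT_CHAT_ENDING,
)

# Hypothetical representative documents for one topic.
docs = ["First representative document.", "Second representative document."]

# 1. The start prompt introduces the task.
chat_messages = [{"role": "user", "content": DEFAULT_PROMPT_CHAT_START}]

# 2. One context message per document, with [DOCUMENTS] replaced by the document text.
for doc in docs:
    chat_messages.append(
        {"role": "user", "content": DEFAULT_PROMPT_CHAT_CONTEXT.replace("[DOCUMENTS]", doc)}
    )

# 3. The ending prompt asks the model for the topic description.
chat_messages.append({"role": "user", "content": DEFAULT_PROMPT_CHAT_ENDING})
```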

---

## Conclusion

The integration of `MistralRepresentation` within the PatientX.AI project provides a robust framework for leveraging LLMs in topic modeling and text analysis. By utilizing customizable prompts, flexible API configurations, and advanced handling of documents, this implementation allows for dynamic and accurate representations of text-based data.

Key takeaways from this documentation include:
- Setting up and using BERTopic for initial topic modeling.
- Understanding and extending the functionality of the `MistralRepresentation` class.
- Interacting with the Chat and Generate APIs for diverse use cases.
- Managing prompts and fine-tuning parameters to optimize results.

This setup enables scalable and efficient processing of large text datasets while offering flexibility to adapt to evolving requirements. Future updates and enhancements will continue to refine the framework, ensuring it remains a powerful tool for text-based AI applications.


---
158 changes: 100 additions & 58 deletions python-notebooks/MistralRepresentation.py
@@ -6,52 +6,25 @@
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document


DEFAULT_PROMPT = """
This is a list of texts where each collection of texts describe a topic. After each collection of texts, the name of the topic they represent is mentioned as a short-highly-descriptive title
---
Topic:
Sample texts from this topic:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the worst food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.
Keywords: meat beef eat eating emissions steak food health processed chicken
Topic name: Environmental impacts of eating meat
---
Topic:
Sample texts from this topic:
- I have ordered the product weeks ago but it still has not arrived!
- The website mentions that it only takes a couple of days to deliver but I still have not received mine.
- I got a message stating that I received the monitor but that is not true!
- It took a month longer to deliver than was advised...
Keywords: deliver weeks product shipping long delivery received arrived arrive week
Topic name: Shipping and delivery issues
---
Topic:
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""

DEFAULT_CHAT_PROMPT = """
I have a topic that contains the following documents:
here are documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
I need you to write "The topic is:" then print a short description of the documents in markdown format.
"""

Based on the information above, extract a short topic label in the following format:
topic: <topic label>
DEFAULT_PROMPT_CHAT_START = """
I will send you documents. After I send you the documents, I will ask you to write a short description of what is going on.
"""

DEFAULT_PROMPT = """
here are documents:
DEFAULT_PROMPT_CHAT_CONTEXT = """
This is one of the documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
I need you to write "The topic is:" then print a short description of the documents in markdown format.
"""

DEFAULT_PROMPT = DEFAULT_PROMPT
DEFAULT_PROMPT_CHAT_ENDING = """
I need you to write "The topic is:" then print a short description of the documents in markdown format.
"""

class MistralRepresentation(BaseRepresentation):
def __init__(
@@ -66,24 +39,25 @@ def __init__(
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None,
api: str = "generate",
):
self.model = model
self.api = api
if prompt is None:
self.prompt = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
self.prompt = DEFAULT_PROMPT if chat else DEFAULT_PROMPT
else:
self.prompt = prompt

self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.delay_in_seconds = delay_in_seconds
self.exponential_backoff = exponential_backoff
self.chat = chat
self.chat = True if self.api == "chat" else False
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer

self.prompts_ = []

self.chat_messages = []
self.generator_kwargs = generator_kwargs
if self.generator_kwargs.get("model"):
self.model = generator_kwargs.get("model")
@@ -92,6 +66,51 @@
del self.generator_kwargs["prompt"]
if not self.generator_kwargs.get("stop") and not chat:
self.generator_kwargs["stop"] = "\n"

def stream_response(self, url, payload) -> str:
"""
Stream responses from a POST request to a specified URL.
Arguments:
url (str): The endpoint URL to which the POST request will be sent.
payload (dict): A dictionary containing the data to be sent in the POST request.
Typically includes the model and input data required for processing.
Returns:
str: The concatenated response content from the server. This includes
either 'message' content (when api/chat is used) or 'response' content
(when api/generate is used), depending on which API endpoint is configured.
"""
response_text = ""
with requests.post(url, json=payload, stream=False) as response:
if response.status_code == 200:
for line in response.iter_lines():
if line:
try:
# Parse JSON and extract the 'response' field
data = json.loads(line.decode("utf-8"))
if "message" in data:
print(
data["message"]["content"], end=""
) # print when api/chat is used

response_text += data["message"]["content"]

self.chat_messages.append(data["message"])
if "response" in data:
print(
data["response"], end=""
) # print when api/generate is used

response_text += data["response"]

except json.JSONDecodeError:
print(f"Failed to decode JSON: {line}")
else:
print(
f"Failed to retrieve response. Status Code: {response.status_code} Response: {response.text}"
)
return response_text

def extract_topics(
self,
@@ -111,23 +130,52 @@ def extract_topics(
Returns:
updated_topics: Updated topic representations
"""
model = "mistral-small"
url = "http://127.0.0.1:11434/api/" + self.api

# Extract the top n representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)

# Generate using Mistral's Language Model
updated_topics = {}
if self.api == "chat":
self.chat_messages = []
self.chat_messages.append({"role": "user", "content": DEFAULT_PROMPT_CHAT_START})
payload = {
"model": model,
"messages": self.chat_messages,
}
response = self.stream_response(url, payload)

for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)

model = "mistral-small"
url = "http://127.0.0.1:11434/api/generate"

# label = get_response(model, url, prompt, messages=prompt, generate=True)
response = get_response(model, url, prompt, messages=prompt, generate=True)
if self.api == "chat":
context_prompt = DEFAULT_PROMPT_CHAT_CONTEXT
# loop over truncated_docs so each document is sent as its own prompt
for doc in truncated_docs:
context_prompt = self._replace_documents(context_prompt, [doc])
self.chat_messages.append({"role": "user", "content": context_prompt})
payload = {
"model": model,
"messages": self.chat_messages,
}
response = self.stream_response(url, payload)

# now ask the LLM for the topic
self.chat_messages.append({"role": "user", "content": DEFAULT_PROMPT_CHAT_ENDING})
payload = {
"model": model,
"messages": self.chat_messages,
}
response = self.stream_response(url, payload)

if self.api == "generate":
response = get_response(model, url, prompt, messages=prompt)

# Extract the topic name from the response
topic_name = self._extract_topic_name(response)
@@ -157,19 +205,13 @@ def _extract_topic_name(self, response: str) -> str:
def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]

# Use the Default Chat Prompt
if self.prompt == DEFAULT_CHAT_PROMPT or self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

return prompt

(Diffs for the remaining 3 changed files are not shown.)
