open_ai_chunk_size_and_k.py
import os

from langchain_astradb import AstraDBVectorStore
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

EMBEDDING_MODEL = "text-embedding-3-small"
LLM_MODEL = "gpt-3.5-turbo"


def get_vector_store(chunk_size: int):
    # One Astra DB collection per chunk size, so each experiment's vectors
    # stay isolated; credentials come from the environment.
    return AstraDBVectorStore(
        embedding=OpenAIEmbeddings(model=EMBEDDING_MODEL),
        collection_name=f"chunk_size_{chunk_size}",
        token=os.getenv("ASTRA_DB_TOKEN"),
        api_endpoint=os.getenv("ASTRA_DB_ENDPOINT"),
    )


def ingest(file_path: str, chunk_size: int, **kwargs):
    vector_store = get_vector_store(chunk_size=chunk_size)
    # Overlap a quarter of the chunk, capped at 64 tokens; integer division
    # keeps the value an int, as the splitter expects.
    chunk_overlap = min(chunk_size // 4, 64)
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        model_name=EMBEDDING_MODEL,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    docs = UnstructuredFileLoader(
        file_path=file_path, mode="single", strategy="fast"
    ).load()
    split_docs = text_splitter.split_documents(docs)
    vector_store.add_documents(split_docs)


def query_pipeline(k: int, chunk_size: int, **kwargs):
    vector_store = get_vector_store(chunk_size=chunk_size)
    llm = ChatOpenAI(model=LLM_MODEL)

    # Build a prompt that restricts the model to the retrieved context.
    prompt_template = """
    Answer the question based only on the supplied context. If you don't know the answer, say: "I don't know".
    Context: {context}
    Question: {question}
    Your answer:
    """
    prompt = ChatPromptTemplate.from_template(prompt_template)

    # Retrieve the top-k chunks, stuff them into the prompt, and parse the
    # model's reply down to a plain string.
    rag_chain = (
        {
            "context": vector_store.as_retriever(search_kwargs={"k": k}),
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain
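

# A minimal usage sketch, not part of the original script. Everything below is
# illustrative: the file path, chunk_size, k, and the question are hypothetical,
# and it assumes ASTRA_DB_TOKEN, ASTRA_DB_ENDPOINT, and OPENAI_API_KEY are set
# in the environment.
if __name__ == "__main__":
    ingest(file_path="example.pdf", chunk_size=512)  # hypothetical document
    chain = query_pipeline(k=4, chunk_size=512)  # retrieve the top 4 chunks
    print(chain.invoke("What is the document about?"))  # hypothetical question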