"""Langchain integration demo"""
import os
import tempfile

import streamlit as st
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma

from retrievals.tools.langchain import LangchainEmbedding, LangchainLLM

st.set_page_config(page_title="RAG with Open-retrievals")
with st.sidebar:
    st.write("**RAG with Open-retrievals**")
@st.cache_resource(ttl="1h")
def configure_retriever(uploaded_files):
    # Persist the uploaded PDFs to a temp directory so PyPDFLoader can read them
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        loader = PyPDFLoader(temp_filepath)
        docs.extend(loader.load())

    # Split documents into overlapping chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = text_splitter.split_documents(docs)

    # Embed the chunks and index them in an in-memory Chroma vector store
    embeddings = LangchainEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectordb = Chroma.from_documents(splits, embeddings)

    # Search parameters belong in `search_kwargs`; `score_threshold` only takes
    # effect with the "similarity_score_threshold" search type
    retriever = vectordb.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"score_threshold": 0.15, "k": 30},
    )
    return retriever

uploaded_files = st.sidebar.file_uploader(label="Upload PDF files", type=["pdf"], accept_multiple_files=True)
if not uploaded_files:
    st.info("Please upload PDF documents to continue.")
    st.stop()
retriever = configure_retriever(uploaded_files)
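
# Load a local chat model through open-retrievals' LangChain-compatible wrapper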
llm = LangchainLLM(model_name_or_path="Qwen/Qwen1.5-1.8B-Chat", temperature=0.5, max_tokens=2048, top_k=10)
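# Streamlit-backed chat history, kept in session state across reruns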
msgs = StreamlitChatMessageHistory()
# Llama-style instruction prompt: <<SYS>> wraps the system message and the
# <context> tags wrap the retrieved chunks referenced in the instructions
RESPONSE_TEMPLATE = """[INST]
<<SYS>>
You are a helpful AI assistant. Use the following pieces of context to answer the user's question.
<</SYS>>
Anything between the following `context` html blocks is retrieved from a knowledge base.

<context>
{context}
</context>

REMEMBER:
- If you don't know the answer, just say that you don't know, don't try to make up an answer.
- Let's take a deep breath and think step-by-step.
Question: {question}[/INST]
Helpful Answer:
"""
PROMPT = PromptTemplate(template=RESPONSE_TEMPLATE, input_variables=["context", "question"])
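
# The "stuff" chain concatenates every retrieved chunk into the prompt's {context} slot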
qa_chain = RetrievalQA.from_chain_type(
    llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={
        "verbose": True,
        "prompt": PROMPT,
    },
)
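
# Seed the history on first load, or reset it when the user starts a new chat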
if len(msgs.messages) == 0 or st.sidebar.button("New Chat"):
    msgs.clear()
    msgs.add_ai_message("How can I help you?")
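
# Replay the stored conversation so it survives Streamlit reruns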
avatars = {"human": "user", "ai": "assistant"}
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)
if user_query := st.chat_input(placeholder="Ask me anything!"):
    st.chat_message("user").write(user_query)
    # Record the user turn so it is replayed on the next rerun
    msgs.add_user_message(user_query)

    with st.chat_message("assistant"):
        response = qa_chain({"query": user_query})
        answer = response["result"]
        st.write(answer)
        # Record the assistant turn as well
        msgs.add_ai_message(answer)
about = st.sidebar.expander("About")
about.write("Powered by [open-retrievals](https://github.com/LongxingTan/open-retrievals)")
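
# Launch with Streamlit's standard entry point:
#   streamlit run rag_langchain_demo.py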