forked from praneetdhoolia/retrieval-agent-template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathretrieval.py
157 lines (122 loc) · 5.67 KB
/
retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Manage the configuration of various retrievers.
This module provides functionality to create and manage retrievers for different
vector store backends, specifically Elasticsearch, Pinecone, MongoDB, and Milvus Lite.
The retrievers support filtering results by user_id to ensure data isolation between users.
"""
import os
from contextlib import contextmanager
from typing import Generator
from langchain_core.embeddings import Embeddings
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStoreRetriever
from retrieval_graph.configuration import Configuration, IndexConfiguration
## Encoder constructors
def make_text_encoder(model: str) -> Embeddings:
"""Connect to the configured text encoder."""
provider, model = model.split("/", maxsplit=1)
match provider:
case "openai":
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings(model=model)
case "cohere":
from langchain_cohere import CohereEmbeddings
return CohereEmbeddings(model=model) # type: ignore
case _:
raise ValueError(f"Unsupported embedding provider: {provider}")
## Retriever constructors
@contextmanager
def make_elastic_retriever(
    configuration: IndexConfiguration, embedding_model: Embeddings
) -> Generator[VectorStoreRetriever, None, None]:
    """Configure this agent to connect to a specific elastic index.

    Credentials come from the environment: ``ELASTICSEARCH_USER`` and
    ``ELASTICSEARCH_PASSWORD`` when the provider is ``"elastic-local"``,
    otherwise ``ELASTICSEARCH_API_KEY``. ``ELASTICSEARCH_URL`` is always read.

    Args:
        configuration: Index configuration supplying ``retriever_provider``,
            ``search_kwargs`` and ``user_id``.
        embedding_model: Embeddings used by the vector store.

    Yields:
        A retriever over the ``langchain_index`` index whose results are
        restricted to documents with ``metadata.user_id`` equal to
        ``configuration.user_id``.

    Raises:
        KeyError: If a required environment variable is not set.
    """
    from langchain_elasticsearch import ElasticsearchStore

    if configuration.retriever_provider == "elastic-local":
        connection_options = {
            "es_user": os.environ["ELASTICSEARCH_USER"],
            "es_password": os.environ["ELASTICSEARCH_PASSWORD"],
        }
    else:
        connection_options = {"es_api_key": os.environ["ELASTICSEARCH_API_KEY"]}

    vstore = ElasticsearchStore(
        **connection_options,  # type: ignore
        es_url=os.environ["ELASTICSEARCH_URL"],
        index_name="langchain_index",
        embedding=embedding_model,
    )

    # Build fresh containers instead of mutating configuration.search_kwargs:
    # that dict may be shared with the caller's config, and appending in place
    # would pile up one user filter per call.
    search_kwargs = dict(configuration.search_kwargs or {})
    existing_filters = search_kwargs.get("filter") or []
    if isinstance(existing_filters, dict):
        # Tolerate a single clause given as a bare dict.
        existing_filters = [existing_filters]
    search_kwargs["filter"] = [
        *existing_filters,
        {"term": {"metadata.user_id": configuration.user_id}},
    ]

    yield vstore.as_retriever(search_kwargs=search_kwargs)
@contextmanager
def make_pinecone_retriever(
    configuration: IndexConfiguration, embedding_model: Embeddings
) -> Generator[VectorStoreRetriever, None, None]:
    """Configure this agent to connect to a specific pinecone index.

    The index name is read from the ``PINECONE_INDEX_NAME`` environment
    variable.

    Args:
        configuration: Index configuration supplying ``search_kwargs`` and
            ``user_id``.
        embedding_model: Embeddings used by the vector store.

    Yields:
        A retriever whose ``filter`` search argument carries
        ``user_id == configuration.user_id`` in addition to any filter keys
        already present in ``configuration.search_kwargs``.

    Raises:
        KeyError: If ``PINECONE_INDEX_NAME`` is not set.
    """
    from langchain_pinecone import PineconeVectorStore

    # Work on copies: updating configuration.search_kwargs["filter"] in place
    # would write this user's id into a dict the caller may reuse elsewhere.
    search_kwargs = dict(configuration.search_kwargs or {})
    search_kwargs["filter"] = {
        **(search_kwargs.get("filter") or {}),
        # Listed last so it overrides any caller-supplied "user_id" key.
        "user_id": configuration.user_id,
    }

    vstore = PineconeVectorStore.from_existing_index(
        os.environ["PINECONE_INDEX_NAME"], embedding=embedding_model
    )
    yield vstore.as_retriever(search_kwargs=search_kwargs)
@contextmanager
def make_mongodb_retriever(
    configuration: IndexConfiguration, embedding_model: Embeddings
) -> Generator[VectorStoreRetriever, None, None]:
    """Configure this agent to connect to a specific MongoDB Atlas index & namespaces.

    The connection string is read from the ``MONGODB_URI`` environment
    variable and the ``langgraph_retrieval_agent.default`` namespace is used.

    Args:
        configuration: Index configuration supplying ``search_kwargs`` and
            ``user_id``.
        embedding_model: Embeddings used by the vector store.

    Yields:
        A retriever whose ``pre_filter`` search argument carries
        ``user_id == configuration.user_id`` in addition to any pre-filter
        keys already present in ``configuration.search_kwargs``.

    Raises:
        KeyError: If ``MONGODB_URI`` is not set.
    """
    from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch

    vstore = MongoDBAtlasVectorSearch.from_connection_string(
        os.environ["MONGODB_URI"],
        namespace="langgraph_retrieval_agent.default",
        embedding=embedding_model,
    )

    # Work on copies: assigning into configuration.search_kwargs["pre_filter"]
    # in place would write this user's id into a dict the caller may reuse.
    search_kwargs = dict(configuration.search_kwargs or {})
    search_kwargs["pre_filter"] = {
        **(search_kwargs.get("pre_filter") or {}),
        # Listed last so it overrides any caller-supplied "user_id" key.
        "user_id": {"$eq": configuration.user_id},
    }

    yield vstore.as_retriever(search_kwargs=search_kwargs)
@contextmanager
def make_milvus_retriever(
    configuration: IndexConfiguration, embedding_model: Embeddings, **kwargs
) -> Generator[VectorStoreRetriever, None, None]:
    """Configure this agent to use a Milvus vector index.

    The connection URI is the ``alternate_milvus_uri`` keyword argument when
    it is given and truthy, otherwise the ``MILVUS_DB`` environment variable
    (per the original note, a Milvus Lite file-based URI). The collection is
    named after ``configuration.user_id``.

    Args:
        configuration: Index configuration supplying ``user_id`` and
            ``search_kwargs``.
        embedding_model: Embeddings used by the vector store.
        **kwargs: Optional ``alternate_milvus_uri`` override; other keys are
            ignored.

    Yields:
        A retriever over the user's collection, using
        ``configuration.search_kwargs`` as given.
    """
    from langchain_milvus.vectorstores import Milvus

    uri_override = kwargs.get("alternate_milvus_uri")
    connection_args = {
        "uri": uri_override if uri_override else os.environ.get("MILVUS_DB")
    }

    store = Milvus(
        embedding_function=embedding_model,
        collection_name=configuration.user_id,
        connection_args=connection_args,
        auto_id=True,
    )
    retriever = store.as_retriever(search_kwargs=configuration.search_kwargs)
    yield retriever
@contextmanager
def make_retriever(
config: RunnableConfig,
**kwargs,
) -> Generator[VectorStoreRetriever, None, None]:
"""Create a retriever for the agent, based on the current configuration."""
configuration = IndexConfiguration.from_runnable_config(config)
embedding_model = make_text_encoder(configuration.embedding_model)
user_id = configuration.user_id
if not user_id:
raise ValueError("Please provide a valid user_id in the configuration.")
match configuration.retriever_provider:
case "elastic" | "elastic-local":
with make_elastic_retriever(configuration, embedding_model) as retriever:
yield retriever
case "pinecone":
with make_pinecone_retriever(configuration, embedding_model) as retriever:
yield retriever
case "mongodb":
with make_mongodb_retriever(configuration, embedding_model) as retriever:
yield retriever
case "milvus":
with make_milvus_retriever(configuration, embedding_model, **kwargs) as retriever:
yield retriever
case _:
raise ValueError(
"Unrecognized retriever_provider in configuration. "
f"Expected one of: {', '.join(Configuration.__annotations__['retriever_provider'].__args__)}\n"
f"Got: {configuration.retriever_provider}"
)