Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cache for embeddings #711

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions src/vanna/ZhipuAI/ZhipuAI_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List
from zhipuai import ZhipuAI

from chromadb import Documents, EmbeddingFunction, Embeddings
from zhipuai import ZhipuAI

from ..base import VannaBase


class ZhipuAI_Embeddings(VannaBase):
"""
[future functionality] This class is used to generate embeddings from ZhipuAI.
Expand All @@ -16,41 +19,45 @@ def __init__(self, config=None):
raise Exception("Missing api_key in config")
self.api_key = config["api_key"]
self.client = ZhipuAI(api_key=self.api_key)
self.embedding_cache = {
'question': None,
'embedding': None
}

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Return the ZhipuAI "embedding-2" vector for *data*.

    A single-entry cache avoids a repeated API call when the same text is
    embedded twice in a row (e.g. the same question embedded by several
    retrieval helpers in one request).

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    # Cache hit: reuse the stored vector instead of calling the API again.
    if self.embedding_cache['question'] == data:
        return self.embedding_cache['embedding']

    response = self.client.embeddings.create(
        model="embedding-2",
        input=data,
    )
    vector = response.data[0].embedding
    # Overwrite the one-entry cache with the latest question/vector pair.
    # Store only the vector, not the whole API response object.
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = vector
    return vector
class ZhipuAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""
An embedding function that uses ZhipuAI to generate embeddings which can be used in chromadb.
usage:
class MyVanna(ChromaDB_VectorStore, ZhipuAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
ZhipuAI_Chat.__init__(self, config=config)

config={'api_key': 'xxx'}
zhipu_embedding_function = ZhipuAIEmbeddingFunction(config=config)
config = {"api_key": "xxx", "model": "glm-4","path":"xy","embedding_function":zhipu_embedding_function}

vn = MyVanna(config)

"""
def __init__(self, config=None):
if config is None or "api_key" not in config:
raise ValueError("Missing 'api_key' in config")

self.api_key = config["api_key"]
self.model_name = config.get("model_name", "embedding-2")

try:
self.client = ZhipuAI(api_key=self.api_key)
except Exception as e:
Expand All @@ -76,4 +83,4 @@ def __call__(self, input: Documents) -> Embeddings:
except Exception as e:
raise ValueError(f"Error generating embedding for document: {e}")

return all_embeddings
return all_embeddings
12 changes: 11 additions & 1 deletion src/vanna/azuresearch/azuresearch_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ def __init__(self, config=None):
if self.index_name not in self._get_indexes():
self._create_index()

self.embedding_cache = {
'question': None,
'embedding': None
}

def _create_index(self) -> bool:
fields = [
SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
Expand Down Expand Up @@ -232,5 +237,10 @@ def remove_index(self):

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Return the fastembed vector for *data*, with a one-entry cache.

    The cache check comes first so that a repeated question skips both the
    (expensive) TextEmbedding model construction and the embedding call.

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    if self.embedding_cache['question'] == data:
        # Cached value is the raw fastembed array; convert on the way out.
        return self.embedding_cache['embedding'].tolist()

    # Only build the model on a cache miss.
    embedding_model = TextEmbedding(model_name=self.fastembed_model)
    embedding = next(embedding_model.embed(data))
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding.tolist()
17 changes: 13 additions & 4 deletions src/vanna/chromadb/chromadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,21 @@ def __init__(self, config=None):
embedding_function=self.embedding_function,
metadata=collection_metadata,
)
self.embedding_cache = {
'question': None,
'embedding': None
}

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Embed *data* via the configured chromadb embedding function.

    Uses a single-entry cache so the same question is not re-embedded on
    consecutive calls. The previous version returned early on the common
    single-embedding path without updating the cache, so the cache never
    filled; this version caches the unwrapped value before returning.

    Args:
        data: The text to embed.

    Returns:
        The embedding for *data* (unwrapped when the function returns a
        single embedding).
    """
    if self.embedding_cache['question'] == data:
        return self.embedding_cache['embedding']

    embedding = self.embedding_function([data])
    if len(embedding) == 1:
        # Unwrap the single embedding before caching so the hit path
        # returns exactly what the miss path returns.
        embedding = embedding[0]
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
question_sql_json = json.dumps(
Expand Down
42 changes: 26 additions & 16 deletions src/vanna/faiss/faiss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import json
import os
import uuid
from typing import List, Dict, Any
from typing import Any, Dict, List

import faiss
import numpy as np
Expand All @@ -10,13 +10,14 @@
from ..base import VannaBase
from ..exceptions import DependencyError


class FAISS(VannaBase):
def __init__(self, config=None):
if config is None:
config = {}

VannaBase.__init__(self, config=config)

try:
import faiss
except ImportError:
Expand All @@ -30,7 +31,7 @@ def __init__(self, config=None):
raise DependencyError(
"SentenceTransformer is not installed. Please install it with 'pip install sentence-transformers'."
)

self.path = config.get("path", ".")
self.embedding_dim = config.get('embedding_dim', 384)
self.n_results_sql = config.get('n_results_sql', config.get("n_results", 10))
Expand Down Expand Up @@ -59,6 +60,10 @@ def __init__(self, config=None):

model_name = config.get('embedding_model', 'all-MiniLM-L6-v2')
self.embedding_model = SentenceTransformer(model_name)
self.embedding_cache = {
'question': None,
'embedding': None
}

def _load_or_create_index(self, filename):
filepath = os.path.join(self.path, filename)
Expand All @@ -85,18 +90,23 @@ def _save_metadata(self, metadata, filename):
json.dump(metadata, f)

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Encode *data* with the SentenceTransformer model, with a one-entry cache.

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats of length
        ``self.embedding_dim``.
    """
    if self.embedding_cache['question'] == data:
        # Cached value is the raw encoder output; convert on the way out.
        return self.embedding_cache['embedding'].tolist()

    embedding = self.embedding_model.encode(data)
    assert embedding.shape[0] == self.embedding_dim, \
        f"Embedding dimension mismatch: expected {self.embedding_dim}, got {embedding.shape[0]}"
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding.tolist()

def _add_to_index(self, index, metadata_list, text, extra_metadata=None) -> str:
    """Embed *text*, add it to the FAISS *index*, and record its metadata.

    Returns the freshly generated entry id.
    """
    vector = self.generate_embedding(text)
    index.add(np.array([vector], dtype=np.float32))
    new_id = str(uuid.uuid4())
    record = {"id": new_id}
    if extra_metadata:
        # Caller-supplied metadata may add (or override) keys on the record.
        record.update(extra_metadata)
    metadata_list.append(record)
    return new_id

def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
entry_id = self._add_to_index(self.sql_index, self.sql_metadata, question + " " + sql, {"question": question, "sql": sql})
self._save_index(self.sql_index, 'sql_index.faiss')
Expand All @@ -122,7 +132,7 @@ def _get_similar(self, index, metadata_list, text, n_results) -> list:

def get_similar_question_sql(self, question: str, **kwargs) -> list:
    """Return stored question/SQL entries most similar to *question*."""
    return self._get_similar(
        self.sql_index, self.sql_metadata, question, self.n_results_sql
    )

def get_related_ddl(self, question: str, **kwargs) -> list:
    """Return the DDL strings most relevant to *question*."""
    hits = self._get_similar(
        self.ddl_index, self.ddl_metadata, question, self.n_results_ddl
    )
    return [hit["ddl"] for hit in hits]

Expand Down Expand Up @@ -155,22 +165,22 @@ def remove_training_data(self, id: str, **kwargs) -> bool:
if embeddings:
new_index.add(np.array(embeddings, dtype=np.float32))
setattr(self, index_name.split('.')[0], new_index)

if self.curr_client == 'persistent':
self._save_index(new_index, index_name)
self._save_metadata(metadata_list, f"{index_name.split('.')[0]}_metadata.json")

return True
return False

def remove_collection(self, collection_name: str) -> bool:
    """Reset one of the "sql" / "ddl" / "documentation" collections.

    Replaces the in-memory index with a fresh, empty FAISS index and clears
    its metadata; when running with a persistent client the emptied state is
    also written to disk. (The previous version had a duplicated, unreachable
    trailing ``return False``.)

    Args:
        collection_name: One of "sql", "ddl", or "documentation".

    Returns:
        True when *collection_name* is recognised, False otherwise.
    """
    if collection_name not in ("sql", "ddl", "documentation"):
        return False

    setattr(self, f"{collection_name}_index", faiss.IndexFlatL2(self.embedding_dim))
    setattr(self, f"{collection_name}_metadata", [])

    if self.curr_client == 'persistent':
        self._save_index(getattr(self, f"{collection_name}_index"), f"{collection_name}_index.faiss")
        self._save_metadata([], f"{collection_name}_metadata.json")

    return True
16 changes: 11 additions & 5 deletions src/vanna/google/bigquery_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
import os
import uuid
from typing import List, Optional
from vertexai.language_models import (
TextEmbeddingInput,
TextEmbeddingModel
)

import pandas as pd
from google.cloud import bigquery
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel

from ..base import VannaBase

Expand Down Expand Up @@ -82,6 +79,10 @@ def __init__(self, config: dict, **kwargs):
# Table does not exist, create it
self.conn.create_table(table, timeout=30) # Make an API request.
print(f"Created table {self.table_id}")
self.embedding_cache = {
'question': None,
'embedding': None
}

# Create VECTOR INDEX IF NOT EXISTS
# TODO: This requires 5000 rows before it can be created
Expand Down Expand Up @@ -192,7 +193,12 @@ def generate_storage_embedding(self, data: str, **kwargs) -> List[float]:
return result

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Return the storage embedding for *data*, with a one-entry cache.

    The previous version computed the embedding twice on a cache miss
    (once to fill the cache and again for the return value); this version
    computes it once and returns the cached result.

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    if self.embedding_cache['question'] == data:
        return self.embedding_cache['embedding']

    embedding = self.generate_storage_embedding(data, **kwargs)
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding

def get_similar_question_sql(self, question: str, **kwargs) -> list:
df = self.fetch_similar_training_data(training_data_type="sql", question=question, n_results=self.n_results_sql)
Expand Down
11 changes: 10 additions & 1 deletion src/vanna/milvus/milvus_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def __init__(self, config=None):
else:
self.embedding_function = model.DefaultEmbeddingFunction()
self._embedding_dim = self.embedding_function.encode_documents(["foo"])[0].shape[0]
self.embedding_cache = {
'question': None,
'embedding': None
}
self._create_collections()
self.n_results = config.get("n_results", 10)

Expand All @@ -56,7 +60,12 @@ def _create_collections(self):


def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Encode *data* with the Milvus embedding function, with a one-entry cache.

    The previous version cached the embedding *function* rather than the
    embedding, so the hit path re-encoded the text and the cache saved
    nothing; this version caches the computed vector list.

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    if self.embedding_cache['question'] == data:
        return self.embedding_cache['embedding']

    embedding = self.embedding_function.encode_documents(data).tolist()
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding


def _create_sql_collection(self, name: str):
Expand Down
29 changes: 19 additions & 10 deletions src/vanna/openai/openai_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,25 @@ def __init__(self, client=None, config=None):
if "api_key" in config:
self.client.api_key = config["api_key"]

self.embedding_cache = {
'question': None,
'embedding': None
}

def generate_embedding(self, data: str, **kwargs) -> list[float]:
    """Embed *data* via the OpenAI client, with a one-entry cache.

    On a cache miss exactly one embeddings call is made: with the configured
    Azure-style ``engine`` when present, otherwise with the
    "text-embedding-ada-002" model. (The previous version issued an
    unconditional ada-002 call before the engine check, embedding the text
    twice on every miss.)

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    if self.embedding_cache['question'] == data:
        return self.embedding_cache['embedding'].get("data")[0]["embedding"]

    if self.config is not None and "engine" in self.config:
        embedding = self.client.embeddings.create(
            engine=self.config["engine"],
            input=data,
        )
    else:
        embedding = self.client.embeddings.create(
            model="text-embedding-ada-002",
            input=data,
        )
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding.get("data")[0]["embedding"]
20 changes: 15 additions & 5 deletions src/vanna/pinecone/pinecone_vector.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
from typing import List

from pinecone import Pinecone, PodSpec, ServerlessSpec
import pandas as pd
from fastembed import TextEmbedding
from pinecone import Pinecone, PodSpec, ServerlessSpec

from ..base import VannaBase
from ..utils import deterministic_uuid

from fastembed import TextEmbedding


class PineconeDB_VectorStore(VannaBase):
"""
Expand Down Expand Up @@ -77,6 +77,10 @@ def __init__(self, config=None):
self.serverless_spec = config.get(
"serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2")
)
self.embedding_cache = {
'question': None,
'embedding': None
}
self._setup_index()

def _set_index_host(self, host: str) -> None:
Expand Down Expand Up @@ -271,5 +275,11 @@ def remove_training_data(self, id: str, **kwargs) -> bool:

def generate_embedding(self, data: str, **kwargs) -> List[float]:
    """Return the fastembed vector for *data*, with a one-entry cache.

    The cache check comes first so that a repeated question skips both the
    (expensive) TextEmbedding model construction and the embedding call;
    the previous version built the model before consulting the cache.

    Args:
        data: The text to embed.

    Returns:
        The embedding vector as a list of floats.
    """
    if self.embedding_cache['question'] == data:
        # Cached value is the raw fastembed array; convert on the way out.
        return self.embedding_cache['embedding'].tolist()

    embedding_model = TextEmbedding(model_name=self.fastembed_model)
    embedding = next(embedding_model.embed(data))
    self.embedding_cache['question'] = data
    self.embedding_cache['embedding'] = embedding
    return embedding.tolist()
Loading