chore: Moving Milvus client to PyMilvus (#4907)
* chore: Moving Milvus client to PyMilvus

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* linted and switched implementation to pymilvus

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* adding updates for integration configuration

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* removing drop statement

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
franciscojavierarceo authored Jan 8, 2025
1 parent 76e1e21 commit 5f9b5b5
Showing 4 changed files with 118 additions and 89 deletions.
2 changes: 1 addition & 1 deletion sdk/python/feast/feature_store.py
@@ -1757,7 +1757,7 @@ def retrieve_online_documents(
query: Union[str, List[float]],
top_k: int,
features: Optional[List[str]] = None,
-        distance_metric: Optional[str] = None,
+        distance_metric: Optional[str] = "L2",
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.
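The only functional change above is the default distance metric: callers that omit distance_metric now get "L2" instead of None. A minimal sketch of a call site relying on the new default (the feature reference and query vector are illustrative, not from this commit):

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")
    response = store.retrieve_online_documents(
        features=["item_embeddings:embedding_float"],  # hypothetical feature reference
        query=[1.0, 2.0],
        top_k=5,
        # distance_metric omitted: it now defaults to "L2"
    )
    print(response.to_dict())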
147 changes: 84 additions & 63 deletions sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py
@@ -7,9 +7,8 @@
CollectionSchema,
DataType,
FieldSchema,
-    connections,
+    MilvusClient,
)
-from pymilvus.orm.connections import Connections

from feast import Entity
from feast.feature_view import FeatureView
@@ -85,14 +84,15 @@ class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
"""

type: Literal["milvus"] = "milvus"

host: Optional[StrictStr] = "localhost"
port: Optional[int] = 19530
index_type: Optional[str] = "IVF_FLAT"
metric_type: Optional[str] = "L2"
embedding_dim: Optional[int] = 128
vector_enabled: Optional[bool] = True
nlist: Optional[int] = 128
+    username: Optional[StrictStr] = ""
+    password: Optional[StrictStr] = ""


class MilvusOnlineStore(OnlineStore):
@@ -103,24 +103,23 @@ class MilvusOnlineStore(OnlineStore):
_collections: Dictionary to cache Milvus collections.
"""

-    _conn: Optional[Connections] = None
-    _collections: Dict[str, Collection] = {}
+    client: Optional[MilvusClient] = None
+    _collections: Dict[str, Any] = {}

-    def _connect(self, config: RepoConfig) -> connections:
-        if not self._conn:
-            if not connections.has_connection("feast"):
-                self._conn = connections.connect(
-                    alias="feast",
-                    host=config.online_store.host,
-                    port=str(config.online_store.port),
-                )
-        return self._conn
+    def _connect(self, config: RepoConfig) -> MilvusClient:
+        if not self.client:
+            self.client = MilvusClient(
+                url=f"{config.online_store.host}:{config.online_store.port}",
+                token=f"{config.online_store.username}:{config.online_store.password}"
+                if config.online_store.username and config.online_store.password
+                else "",
+            )
+        return self.client
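Outside of Feast, the equivalent PyMilvus connection is roughly the following sketch; note that MilvusClient's documented keyword for the endpoint is uri, and token is only needed when Milvus authentication is enabled:

    from pymilvus import MilvusClient

    client = MilvusClient(
        uri="http://localhost:19530",  # assumed local standalone Milvus
        token="username:password",     # omit or leave empty when auth is disabled
    )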

-    def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
+    def _get_collection(self, config: RepoConfig, table: FeatureView) -> Dict[str, Any]:
+        self.client = self._connect(config)
collection_name = _table_id(config.project, table)
if collection_name not in self._collections:
-            self._connect(config)

# Create a composite key by combining entity fields
composite_key_name = (
"_".join([field.name for field in table.entity_columns]) + "_pk"
@@ -166,23 +165,38 @@ def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection:
schema = CollectionSchema(
fields=fields, description="Feast feature view data"
)
-            collection = Collection(name=collection_name, schema=schema, using="feast")
-            if not collection.has_index():
-                index_params = {
-                    "index_type": config.online_store.index_type,
-                    "metric_type": config.online_store.metric_type,
-                    "params": {"nlist": config.online_store.nlist},
-                }
-                for vector_field in schema.fields:
-                    if vector_field.dtype in [
-                        DataType.FLOAT_VECTOR,
-                        DataType.BINARY_VECTOR,
-                    ]:
-                        collection.create_index(
-                            field_name=vector_field.name, index_params=index_params
-                        )
-            collection.load()
-            self._collections[collection_name] = collection
+            collection_exists = self.client.has_collection(
+                collection_name=collection_name
+            )
+            if not collection_exists:
+                self.client.create_collection(
+                    collection_name=collection_name,
+                    dimension=config.online_store.embedding_dim,
+                    schema=schema,
+                )
+                index_params = self.client.prepare_index_params()
+                for vector_field in schema.fields:
+                    if vector_field.dtype in [
+                        DataType.FLOAT_VECTOR,
+                        DataType.BINARY_VECTOR,
+                    ]:
+                        index_params.add_index(
+                            collection_name=collection_name,
+                            field_name=vector_field.name,
+                            metric_type=config.online_store.metric_type,
+                            index_type=config.online_store.index_type,
+                            index_name=f"vector_index_{vector_field.name}",
+                            params={"nlist": config.online_store.nlist},
+                        )
+                self.client.create_index(
+                    collection_name=collection_name,
+                    index_params=index_params,
+                )
+            else:
+                self.client.load_collection(collection_name)
+            self._collections[collection_name] = self.client.describe_collection(
+                collection_name
+            )
return self._collections[collection_name]
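Reduced to a standalone sketch against PyMilvus, the create-or-load flow above looks like this (reusing the client from the earlier sketch; collection name, field name, and index parameters are illustrative):

    if not client.has_collection(collection_name="demo_collection"):
        client.create_collection(collection_name="demo_collection", dimension=128)
        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="vector",  # default vector field of a quick-start collection
            index_type="IVF_FLAT",
            metric_type="L2",
            params={"nlist": 128},
        )
        client.create_index(collection_name="demo_collection", index_params=index_params)
    else:
        client.load_collection("demo_collection")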

def online_write_batch(
@@ -199,6 +213,7 @@ def online_write_batch(
],
progress: Optional[Callable[[int], Any]],
) -> None:
+        self.client = self._connect(config)
collection = self._get_collection(config, table)
entity_batch_to_insert = []
for entity_key, values_dict, timestamp, created_ts in data:
@@ -231,8 +246,9 @@ def online_write_batch(
if progress:
progress(1)

-        collection.insert(entity_batch_to_insert)
-        collection.flush()
+        self.client.insert(
+            collection_name=collection["collection_name"], data=entity_batch_to_insert
+        )
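Note the explicit flush() from the ORM path is gone; MilvusClient.insert takes a list of row dicts keyed by field name. A sketch with hypothetical rows, again reusing the client from the earlier sketch:

    rows = [
        {"id": 1, "vector": [0.1, 0.2], "event_ts": 1736294400000000},
        {"id": 2, "vector": [0.3, 0.4], "event_ts": 1736294400000000},
    ]
    client.insert(collection_name="demo_collection", data=rows)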

def online_read(
self,
@@ -252,14 +268,14 @@ def update(
entities_to_keep: Sequence[Entity],
partial: bool,
):
-        self._connect(config)
+        self.client = self._connect(config)
for table in tables_to_keep:
-            self._get_collection(config, table)
+            self._collections = self._get_collection(config, table)

for table in tables_to_delete:
collection_name = _table_id(config.project, table)
-            collection = Collection(name=collection_name)
-            if collection.exists():
-                collection.drop()
+            if self._collections.get(collection_name, None):
+                self.client.drop_collection(collection_name)
+                self._collections.pop(collection_name, None)

def plan(
@@ -273,12 +289,12 @@ def teardown(
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
-        self._connect(config)
+        self.client = self._connect(config)
for table in tables:
-            collection = self._get_collection(config, table)
-            if collection:
-                collection.drop()
-                self._collections.pop(collection.name, None)
+            collection_name = _table_id(config.project, table)
+            if self._collections.get(collection_name, None):
+                self.client.drop_collection(collection_name)
+                self._collections.pop(collection_name, None)
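Teardown likewise routes drops through the client. Guarded against the server rather than the local cache, the same operation is roughly:

    if client.has_collection("demo_collection"):
        client.drop_collection("demo_collection")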

def retrieve_online_documents(
self,
@@ -298,6 +314,8 @@
Optional[ValueProto],
]
]:
+        self.client = self._connect(config)
+        collection_name = _table_id(config.project, table)
collection = self._get_collection(config, table)
if not config.online_store.vector_enabled:
raise ValueError("Vector search is not enabled in the online store config")
@@ -321,42 +339,45 @@
+ ["created_ts", "event_ts"]
)
assert all(
-            field
+            field in [f["name"] for f in collection["fields"]]
            for field in output_fields
-            if field in [f.name for f in collection.schema.fields]
-        ), f"field(s) [{[field for field in output_fields if field not in [f.name for f in collection.schema.fields]]}'] not found in collection schema"
+        ), f"field(s) [{[field for field in output_fields if field not in [f['name'] for f in collection['fields']]]}] not found in collection schema"
# Note we choose the first vector field as the field to search on. Not ideal but it's something.
ann_search_field = None
-        for field in collection.schema.fields:
+        for field in collection["fields"]:
if (
-                field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
-                and field.name in output_fields
+                field["type"] in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR]
+                and field["name"] in output_fields
):
-                ann_search_field = field.name
+                ann_search_field = field["name"]
break

-        results = collection.search(
+        self.client.load_collection(collection_name)
+        results = self.client.search(
+            collection_name=collection_name,
data=[embedding],
anns_field=ann_search_field,
-            param=search_params,
+            search_params=search_params,
limit=top_k,
output_fields=output_fields,
consistency_level="Strong",
)

result_list = []
for hits in results:
for hit in hits:
single_record = {}
for field in output_fields:
-                    single_record[field] = hit.entity.get(field)
+                    single_record[field] = hit.get("entity", {}).get(field, None)

-                entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name))
-                embedding = hit.entity.get(ann_search_field)
+                entity_key_bytes = bytes.fromhex(
+                    hit.get("entity", {}).get(composite_key_name, None)
+                )
+                embedding = hit.get("entity", {}).get(ann_search_field)
serialized_embedding = _serialize_vector_to_float_list(embedding)
-                distance = hit.distance
-                event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6)
+                distance = hit.get("distance", None)
+                event_ts = datetime.fromtimestamp(
+                    hit.get("entity", {}).get("event_ts") / 1e6
+                )
prepared_result = _build_retrieve_online_document_record(
entity_key_bytes,
# This may have a bug
@@ -412,7 +433,7 @@ def __init__(self, host: str, port: int, name: str):
self._connect()

def _connect(self):
-        return connections.connect(alias="default", host=self.host, port=str(self.port))
+        raise NotImplementedError

def to_infra_object_proto(self) -> InfraObjectProto:
# Implement serialization if needed
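With the client API, search returns lists of plain dicts rather than ORM Hit objects, which is why the hunks above switch from attribute access to .get(). A standalone sketch (collection and field names are illustrative):

    results = client.search(
        collection_name="demo_collection",
        data=[[1.0, 2.0]],  # one query vector
        anns_field="vector",
        search_params={"metric_type": "L2", "params": {"nprobe": 10}},
        limit=2,
        output_fields=["id"],
    )
    for hits in results:
        for hit in hits:
            print(hit.get("distance"), hit.get("entity", {}).get("id"))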
@@ -1,6 +1,8 @@
from typing import Any, Dict

-from testcontainers.milvus import MilvusContainer
+import docker
+from testcontainers.core.container import DockerContainer
+from testcontainers.core.waiting_utils import wait_for_logs

from tests.integration.feature_repos.universal.online_store_creator import (
OnlineStoreCreator,
@@ -11,13 +13,19 @@ class MilvusOnlineStoreCreator(OnlineStoreCreator):
def __init__(self, project_name: str, **kwargs):
super().__init__(project_name)
self.fixed_port = 19530
-        self.container = MilvusContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(
+        self.container = DockerContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(
self.fixed_port
)
+        self.client = docker.from_env()

def create_online_store(self) -> Dict[str, Any]:
self.container.start()
+        # Wait for Milvus server to be ready
+        # log_string_to_wait_for = "Ready to accept connections"
+        log_string_to_wait_for = ""
+        wait_for_logs(
+            container=self.container, predicate=log_string_to_wait_for, timeout=30
+        )
host = "localhost"
port = self.container.get_exposed_port(self.fixed_port)
return {
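The creator now uses the generic DockerContainer plus wait_for_logs instead of the Milvus-specific testcontainer. Because the predicate above is an empty string, it matches any log output, so this wait effectively returns immediately. A sketch of the underlying testcontainers pattern, using the log line the commit left commented out:

    from testcontainers.core.container import DockerContainer
    from testcontainers.core.waiting_utils import wait_for_logs

    container = DockerContainer("milvusdb/milvus:v2.4.4").with_exposed_ports(19530)
    container.start()
    wait_for_logs(container=container, predicate="Ready to accept connections", timeout=30)
    port = container.get_exposed_port(19530)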
46 changes: 23 additions & 23 deletions sdk/python/tests/integration/online_store/test_universal_online.py
@@ -897,26 +897,26 @@ def test_retrieve_online_documents(environment, fake_document_data):
).to_dict()


-# @pytest.mark.integration
-# @pytest.mark.universal_online_stores(only=["milvus"])
-# def test_retrieve_online_milvus_documents(environment, fake_document_data):
-#     fs = environment.feature_store
-#     df, data_source = fake_document_data
-#     item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
-#     fs.apply([item_embeddings_feature_view, item()])
-#     fs.write_to_online_store("item_embeddings", df)
-#     documents = fs.retrieve_online_documents(
-#         feature=None,
-#         features=[
-#             "item_embeddings:embedding_float",
-#             "item_embeddings:item_id",
-#             "item_embeddings:string_feature",
-#         ],
-#         query=[1.0, 2.0],
-#         top_k=2,
-#         distance_metric="L2",
-#     ).to_dict()
-#     assert len(documents["embedding_float"]) == 2
-#
-#     assert len(documents["item_id"]) == 2
-#     assert documents["item_id"] == [2, 3]
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["milvus"])
+def test_retrieve_online_milvus_documents(environment, fake_document_data):
+    fs = environment.feature_store
+    df, data_source = fake_document_data
+    item_embeddings_feature_view = create_item_embeddings_feature_view(data_source)
+    fs.apply([item_embeddings_feature_view, item()])
+    fs.write_to_online_store("item_embeddings", df)
+    documents = fs.retrieve_online_documents(
+        feature=None,
+        features=[
+            "item_embeddings:embedding_float",
+            "item_embeddings:item_id",
+            "item_embeddings:string_feature",
+        ],
+        query=[1.0, 2.0],
+        top_k=2,
+        distance_metric="L2",
+    ).to_dict()
+    assert len(documents["embedding_float"]) == 2
+
+    assert len(documents["item_id"]) == 2
+    assert documents["item_id"] == [2, 3]
