[WIP][DO NOT MERGE] AQUA Multi-Model Deployment #1061

Open
wants to merge 43 commits into
base: main
43 commits
1446ea0
Added get multi model deployment config.
lu-ohai Feb 4, 2025
c925ec4
Updated pr.
lu-ohai Feb 4, 2025
3187e2b
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
lu-ohai Feb 4, 2025
25afd80
Updated pr.
lu-ohai Feb 4, 2025
07eb59d
initial AQUA API code changes in get_deployment_default_params
elizjo Feb 5, 2025
5cb74b3
Update to pydantic models
VipulMascarenhas Feb 5, 2025
cbd27c9
Create method to support pydantic model inputs
VipulMascarenhas Feb 5, 2025
d0d9b17
Update handler for post method
VipulMascarenhas Feb 5, 2025
a546668
Update unit tests
VipulMascarenhas Feb 5, 2025
895d2f3
Fix params input
VipulMascarenhas Feb 5, 2025
050f13f
Remove newline from err message for handler
VipulMascarenhas Feb 5, 2025
86e9c33
Updated pr.
lu-ohai Feb 5, 2025
170f7a5
simplify var init
VipulMascarenhas Feb 5, 2025
02fe829
combined multimodel and normal model unit tests for test_get_deployme…
elizjo Feb 5, 2025
4fb0923
Review comments
VipulMascarenhas Feb 5, 2025
f5c3697
[ODSC-56699] Update MD entities to pydantic models (#1059)
VipulMascarenhas Feb 5, 2025
d097ac8
Merge branch 'main' into feature/multi_model_deployment
mrDzurb Feb 5, 2025
8d51a8b
Updated pr.
lu-ohai Feb 5, 2025
c91472f
Merge branch 'feature/multi_model_deployment' of https://github.com/o…
lu-ohai Feb 5, 2025
32b73f3
Updated pr.
lu-ohai Feb 5, 2025
bc2e0b7
ODSC 68320: Modify the AQUA API to Accept GPU Count as an Optional In…
VipulMascarenhas Feb 6, 2025
1e418db
Resolve merge conflicts and ruff update
VipulMascarenhas Feb 6, 2025
0f08a64
Added API to get multi model deployment config (#1055)
VipulMascarenhas Feb 6, 2025
fbd77d5
ODSC-68526:Optimize Multi-Model Configuration Retrieval Using Paralle…
mrDzurb Feb 7, 2025
de211f2
Create multimodel catalog entry
VipulMascarenhas Feb 7, 2025
e533c14
Optimize Multi-Model Configuration Retrieval Using Parallel Execution…
lu-ohai Feb 7, 2025
05e329b
add custom metadata to multimodel
VipulMascarenhas Feb 7, 2025
a49afa2
add metadata and tests
VipulMascarenhas Feb 7, 2025
1f79762
Enhance aqua deployment handler
mrDzurb Feb 8, 2025
910f90c
Merge branch 'main' into feature/multi_model_deployment
mrDzurb Feb 10, 2025
ee3d65d
Enhances the functionality for creating multi-model entities.
mrDzurb Feb 10, 2025
5109d54
Merge branch 'feature/multi_model_deployment' into ODSC-68321/create_…
mrDzurb Feb 10, 2025
c25a3de
[ODSC-68321] Create grouped model with multiple verified models (#1064)
mrDzurb Feb 10, 2025
b9c22a3
Modify get deployment config to pydantic class.
lu-ohai Feb 10, 2025
e324467
Merge branch 'feature/multi_model_deployment' of https://github.com/o…
lu-ohai Feb 10, 2025
d0b0704
Updated pr.
lu-ohai Feb 10, 2025
bddfc41
Fixed unit tests.
lu-ohai Feb 10, 2025
63667f4
Fixed unit test.
lu-ohai Feb 10, 2025
bc3978c
Updated name.
lu-ohai Feb 10, 2025
2ea68b4
Modify get deployment config as pydantic (#1066)
mrDzurb Feb 10, 2025
8082934
Enhances error messages for loading multi-model configurations.
mrDzurb Feb 11, 2025
0edb0ae
Enhances the multi-model config retriever.
mrDzurb Feb 12, 2025
c5c6774
Enhances the multi-model configuration retrieval process. (#1067)
lu-ohai Feb 12, 2025
43 changes: 42 additions & 1 deletion ads/aqua/common/entities.py
@@ -1,7 +1,11 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Optional

from ads.aqua.config.utils.serializer import Serializable


class ContainerSpec:
"""
@@ -15,3 +19,40 @@ class ContainerSpec:
ENV_VARS = "envVars"
RESTRICTED_PARAMS = "restrictedParams"
EVALUATION_CONFIGURATION = "evaluationConfiguration"


class ShapeInfo(Serializable):
instance_shape: Optional[str] = None
instance_count: Optional[int] = None
ocpus: Optional[float] = None
memory_in_gbs: Optional[float] = None

class Config:
extra = "ignore"


class AquaMultiModelRef(Serializable):
"""
Lightweight model descriptor used for multi-model deployment.

This class only contains essential details
required to fetch complete model metadata and deploy models.

Attributes
----------
model_id : str
The unique identifier of the model.
model_name : Optional[str]
The name of the model.
gpu_count : Optional[int]
Number of GPUs required for deployment.
env_var : Optional[dict]
Optional environment variables to override during deployment.
"""

model_id: str
model_name: Optional[str] = None
gpu_count: Optional[int] = None
env_var: Optional[dict] = None

class Config:
extra = "ignore"
protected_namespaces = ()
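
Review note: the `Config` block above reads like pydantic-style configuration, where `extra = "ignore"` drops undeclared keys and `protected_namespaces = ()` lets a field be named `model_id` or `model_name` without a namespace warning. The sketch below only mimics the "ignore unknown keys" behaviour with a standard-library dataclass. `MultiModelRefSketch` and `from_dict` are hypothetical names used for illustration; they are not part of ADS, and `Serializable` may behave differently.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class MultiModelRefSketch:
    """Hypothetical stand-in for AquaMultiModelRef (not the ADS class)."""

    model_id: str
    model_name: Optional[str] = None
    gpu_count: Optional[int] = None
    env_var: Optional[Dict[str, Any]] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MultiModelRefSketch":
        # Mimic extra = "ignore": keep only keys that are declared fields.
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in known})


# The payload below is made up; "unknown_key" is silently dropped.
ref = MultiModelRefSketch.from_dict(
    {"model_id": "example-model-id", "gpu_count": 2, "unknown_key": "dropped"}
)
print(ref)
```

Only `model_id` is required here, matching the field defaults in the diff; everything else falls back to `None`.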
1 change: 1 addition & 0 deletions ads/aqua/common/enums.py
@@ -40,6 +40,7 @@ class Tags(ExtendedEnum):
AQUA_EVALUATION_MODEL_ID = "evaluation_model_id"
MODEL_FORMAT = "model_format"
MODEL_ARTIFACT_FILE = "model_file"
MULTIMODEL_TYPE_TAG = "multimodel"


class InferenceContainerType(ExtendedEnum):
113 changes: 38 additions & 75 deletions ads/aqua/extension/deployment_handler.py
@@ -1,7 +1,8 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import List, Union
from urllib.parse import urlparse

from tornado.web import HTTPError
@@ -11,7 +12,7 @@
from ads.aqua.extension.errors import Errors
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
from ads.aqua.modeldeployment.entities import ModelParams
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
from ads.config import COMPARTMENT_OCID


class AquaDeploymentHandler(AquaAPIhandler):
@@ -20,7 +21,7 @@ class AquaDeploymentHandler(AquaAPIhandler):

Methods
-------
get(self, id="")
get(self, id: Union[str, List[str]])
Retrieves a list of AQUA deployments or model info or logs by ID.
post(self, *args, **kwargs)
Creates a new AQUA deployment.
@@ -37,14 +38,15 @@ class AquaDeploymentHandler(AquaAPIhandler):
"""

@handle_exceptions
def get(self, id=""):
def get(self, id: Union[str, List[str]] = None):
"""Handle GET request."""
url_parse = urlparse(self.request.path)
paths = url_parse.path.strip("/")
if paths.startswith("aqua/deployments/config"):
if not id:
if not id or not isinstance(id, (list, str)):
raise HTTPError(
400, f"The request {self.request.path} requires model id."
400,
f"The request to {self.request.path} must include either a single model ID or a list of model IDs.",
)
return self.get_deployment_config(id)
elif paths.startswith("aqua/deployments"):
@@ -98,71 +100,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
if not input_data:
raise HTTPError(400, Errors.NO_INPUT_DATA)

# required input parameters
display_name = input_data.get("display_name")
if not display_name:
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format("display_name")
)
instance_shape = input_data.get("instance_shape")
if not instance_shape:
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format("instance_shape")
)
model_id = input_data.get("model_id")
if not model_id:
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))

compartment_id = input_data.get("compartment_id", COMPARTMENT_OCID)
project_id = input_data.get("project_id", PROJECT_OCID)
log_group_id = input_data.get("log_group_id")
access_log_id = input_data.get("access_log_id")
predict_log_id = input_data.get("predict_log_id")
description = input_data.get("description")
instance_count = input_data.get("instance_count")
bandwidth_mbps = input_data.get("bandwidth_mbps")
web_concurrency = input_data.get("web_concurrency")
server_port = input_data.get("server_port")
health_check_port = input_data.get("health_check_port")
env_var = input_data.get("env_var")
container_family = input_data.get("container_family")
ocpus = input_data.get("ocpus")
memory_in_gbs = input_data.get("memory_in_gbs")
model_file = input_data.get("model_file")
private_endpoint_id = input_data.get("private_endpoint_id")
container_image_uri = input_data.get("container_image_uri")
cmd_var = input_data.get("cmd_var")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")

self.finish(
AquaDeploymentApp().create(
compartment_id=compartment_id,
project_id=project_id,
model_id=model_id,
display_name=display_name,
description=description,
instance_count=instance_count,
instance_shape=instance_shape,
log_group_id=log_group_id,
access_log_id=access_log_id,
predict_log_id=predict_log_id,
bandwidth_mbps=bandwidth_mbps,
web_concurrency=web_concurrency,
server_port=server_port,
health_check_port=health_check_port,
env_var=env_var,
container_family=container_family,
ocpus=ocpus,
memory_in_gbs=memory_in_gbs,
model_file=model_file,
private_endpoint_id=private_endpoint_id,
container_image_uri=container_image_uri,
cmd_var=cmd_var,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
)
)
self.finish(AquaDeploymentApp().create(**input_data))
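
Review note: with the per-field checks removed from the handler, required-parameter validation and the `COMPARTMENT_OCID` / `PROJECT_OCID` defaults presumably now live inside `create`; this diff does not show that side. A minimal sketch of the pattern, assuming a typed request object built from the forwarded keyword arguments. `CreateRequestSketch` and this `create` function are hypothetical and cover only a subset of the fields; unlike a model configured with `extra = "ignore"`, the dataclass here rejects unknown keys.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class CreateRequestSketch:
    """Hypothetical subset of the fields the old handler read one by one."""

    display_name: str
    instance_shape: str
    model_id: str
    description: Optional[str] = None
    instance_count: Optional[int] = None


def create(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for AquaDeploymentApp.create."""
    try:
        request = CreateRequestSketch(**kwargs)
    except TypeError as err:
        # A missing required field or an unexpected key surfaces here.
        raise ValueError(f"Invalid deployment parameters: {err}") from err
    return {"display_name": request.display_name, "model_id": request.model_id}


# Made-up payload, forwarded the same way the handler forwards input_data.
input_data = {
    "display_name": "demo",
    "instance_shape": "example-shape",
    "model_id": "example-model-id",
}
result = create(**input_data)
```

The handler stays a one-liner, and a missing `display_name`, `instance_shape`, or `model_id` is reported by the callee instead of by three separate `HTTPError` branches.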

def read(self, id):
"""Read the information of an Aqua model deployment."""
@@ -181,9 +119,33 @@ def list(self):
)
)

def get_deployment_config(self, model_id):
"""Gets the deployment config for Aqua model."""
return self.finish(AquaDeploymentApp().get_deployment_config(model_id=model_id))
def get_deployment_config(self, model_id: Union[str, List[str]]):
"""
Retrieves the deployment configuration for one or more Aqua models.

Parameters
----------
model_id : Union[str, List[str]]
A single model ID (str) or a list of model IDs (List[str]).

Returns
-------
None
The function sends the deployment configuration as a response.
"""
app = AquaDeploymentApp()

if isinstance(model_id, list):
# Handle multiple model deployment
primary_model_id = self.get_argument("primary_model_id", default=None)
deployment_config = app.get_multimodel_deployment_config(
model_ids=model_id, primary_model_id=primary_model_id
)
else:
# Handle single model deployment
deployment_config = app.get_deployment_config(model_id=model_id)

return self.finish(deployment_config)
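
Review note: the branch above can be exercised in isolation. The sketch below reproduces the `isinstance(model_id, list)` dispatch against a fake app object. `FakeDeploymentApp` and `resolve_config` are hypothetical names; only the two method names and their keyword arguments are taken from the diff, and the returned dictionaries are invented for the example.

```python
from typing import Any, Dict, List, Optional, Union


class FakeDeploymentApp:
    """Hypothetical stand-in for AquaDeploymentApp; returns labelled dicts."""

    def get_deployment_config(self, model_id: str) -> Dict[str, Any]:
        return {"mode": "single", "model_ids": [model_id]}

    def get_multimodel_deployment_config(
        self, model_ids: List[str], primary_model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        return {
            "mode": "multi",
            "model_ids": model_ids,
            "primary_model_id": primary_model_id,
        }


def resolve_config(
    app: FakeDeploymentApp,
    model_id: Union[str, List[str]],
    primary_model_id: Optional[str] = None,
) -> Dict[str, Any]:
    # A list selects the multi-model path; a plain string the single-model path.
    if isinstance(model_id, list):
        return app.get_multimodel_deployment_config(
            model_ids=model_id, primary_model_id=primary_model_id
        )
    return app.get_deployment_config(model_id=model_id)


app = FakeDeploymentApp()
single = resolve_config(app, "model-a")
multi = resolve_config(app, ["model-a", "model-b"], primary_model_id="model-a")
```

Note that a one-element list still takes the multi-model path in this sketch, as it does in the handler code above.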


class AquaDeploymentInferenceHandler(AquaAPIhandler):
@@ -259,9 +221,10 @@ class AquaDeploymentParamsHandler(AquaAPIhandler):
def get(self, model_id):
"""Handle GET request."""
instance_shape = self.get_argument("instance_shape")
gpu_count = self.get_argument("gpu_count", default=None)
return self.finish(
AquaDeploymentApp().get_deployment_default_params(
model_id=model_id, instance_shape=instance_shape
model_id=model_id, instance_shape=instance_shape, gpu_count=gpu_count
)
)
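
Review note: Tornado's `get_argument` returns query values as strings, so `gpu_count` reaches `get_deployment_default_params` as `"2"` rather than `2` unless it is converted somewhere. This diff does not show whether the callee coerces it. A minimal sketch of one way to parse an optional integer parameter; `parse_optional_int` is a hypothetical helper, not part of ADS or Tornado.

```python
from typing import Optional


def parse_optional_int(raw: Optional[str], name: str) -> Optional[int]:
    """Convert an optional query-string value to a positive int, or None."""
    if raw is None or raw == "":
        return None
    try:
        value = int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err
    if value < 1:
        raise ValueError(f"{name} must be at least 1, got {value}")
    return value


# Example values as they would arrive from a query string.
gpu_count = parse_optional_int("2", "gpu_count")
missing = parse_optional_int(None, "gpu_count")
```

In a handler, the `ValueError` would typically be translated into a 400 response.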

2 changes: 2 additions & 0 deletions ads/aqua/model/constants.py
@@ -18,6 +18,7 @@ class ModelCustomMetadataFields(ExtendedEnum):
EVALUATION_CONTAINER = "evaluation-container"
FINETUNE_CONTAINER = "finetune-container"
DEPLOYMENT_CONTAINER_URI = "deployment-container-uri"
MULTIMODEL_GROUP_COUNT = "model_group_count"


class ModelTask(ExtendedEnum):
@@ -34,6 +35,7 @@ class FineTuningMetricCategories(ExtendedEnum):
class ModelType(ExtendedEnum):
FT = "FT" # Fine Tuned Model
BASE = "BASE" # Base model
MULTIMODEL = "MULTIMODEL"


# TODO: merge metadata key used in create FT