Skip to content

Commit

Permalink
Merge pull request #9 from gizatechxyz/deployment-uri
Browse files Browse the repository at this point in the history
Run Cairo program from deployment URI
  • Loading branch information
raphaelDkhn authored Jan 26, 2024
2 parents 9063356 + c8e16ba commit 329ec68
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 16 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ dmypy.json
# Pyre type checker
.pyre/

**/data
**/*.onnx
**/*.txt
**/*.jpg
Expand Down
29 changes: 16 additions & 13 deletions giza_actions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from giza.utils.enums import VersionStatus
from osiris.app import create_tensor_from_array, deserialize, serialize, serializer

from giza_actions.utils import get_deployment_uri


class GizaModel:
def __init__(
Expand All @@ -18,37 +20,40 @@ def __init__(
id: Optional[int] = None,
version: Optional[int] = None,
output_path: Optional[str] = None,
orion_runner_service_url: Optional[str] = None,
):
if model_path is None and id is None and version is None:
raise ValueError("Either model_path or id and version must be provided.")
raise ValueError(
"Either model_path or id and version must be provided.")

if model_path is None and (id is None or version is None):
raise ValueError("Both id and version must be provided.")

if model_path and (id or version):
raise ValueError("Either model_path or id and version must be provided.")

self.orion_runner_service_url = orion_runner_service_url
raise ValueError(
"Either model_path or id and version must be provided.")

if model_path:
self.session = ort.InferenceSession(model_path)
elif id and version:
self.model_client = ModelsClient(API_HOST)
self.version_client = VersionsClient(API_HOST)
self.api_client = ApiClient(API_HOST)
self.uri = get_deployment_uri(id, version)
self._get_credentials()
self._download_model(id, version, output_path)
self.session = None
if output_path:
self._download_model(id, version, output_path)

def _download_model(self, model_id: int, version_id: int, output_path: str):
version = self.version_client.get(model_id, version_id)

if version.status != VersionStatus.COMPLETED:
raise ValueError(f"Model version status is not completed {version.status}")
raise ValueError(
f"Model version status is not completed {version.status}")

print("ONNX model is ready, downloading! ✅")
onnx_model = self.api_client.download_original(model_id, version.version)
onnx_model = self.api_client.download_original(
model_id, version.version)

model_name = version.original_model_path.split("/")[-1]
save_path = Path(output_path) / model_name
Expand All @@ -73,10 +78,10 @@ def predict(
output_dtype: str = "tensor_fixed_point",
):
if verifiable:
if not self.orion_runner_service_url:
raise ValueError("Orion Runner service URL must be provided")
if not self.uri:
raise ValueError("Model has not been deployed")

endpoint = f"{self.orion_runner_service_url}/cairo_run"
endpoint = f"{self.uri}/cairo_run"

cairo_payload = self._format_inputs_for_cairo(
input_file, input_feed, fp_impl
Expand Down Expand Up @@ -107,8 +112,6 @@ def _format_inputs_for_cairo(
serialized = None

if input_file is not None:
print(input_file)

serialized = serialize(input_file, fp_impl)

if input_feed is not None:
Expand Down
28 changes: 27 additions & 1 deletion giza_actions/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from giza import API_HOST
from giza.client import WorkspaceClient
from giza.client import WorkspaceClient, DeploymentsClient


def get_workspace_uri():
Expand All @@ -16,3 +16,29 @@ def get_workspace_uri():
client = WorkspaceClient(API_HOST)
workspace = client.get()
return workspace.url


def get_deployment_uri(model_id: int, version_id: int):
    """
    Look up the URI of the first deployment listed for a model version.

    A DeploymentsClient is created against API_HOST and asked to list the
    deployments for the given model/version pair. Only the first entry of
    that listing is considered.

    Args:
        model_id (int): The ID of the model.
        version_id (int): The ID of the version.

    Returns:
        The ``uri`` attribute of the first listed deployment, or ``None``
        when the listing contains no deployments.
    """
    listing = DeploymentsClient(API_HOST).list(model_id, version_id)

    # NOTE(review): ``__root__`` looks like a pydantic custom-root container
    # wrapping the actual sequence of deployments -- confirm against the
    # giza client models.
    entries = listing.__root__

    # Guard clause: nothing deployed (or an empty listing) yields no URI.
    if not entries:
        return None

    return entries[0].uri
6 changes: 4 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# TODO: Implement a test env.

import numpy as np

from giza_actions.model import GizaModel


def test_predict_success():
model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080")
model = GizaModel(id=50, version=2)

arr = np.array([[1, 2], [3, 4]], dtype=np.uint32)

Expand All @@ -16,7 +18,7 @@ def test_predict_success():


def test_predict_success_with_file():
model = GizaModel(model_path="", orion_runner_service_url="http://localhost:8080")
model = GizaModel(id=50, version=2)

expected = np.array([[1, 2], [3, 4]], dtype=np.uint32)

Expand Down

0 comments on commit 329ec68

Please sign in to comment.