diff --git a/clients/python/src/model_registry/_client.py b/clients/python/src/model_registry/_client.py index c75a333c..9768cf05 100644 --- a/clients/python/src/model_registry/_client.py +++ b/clients/python/src/model_registry/_client.py @@ -9,6 +9,8 @@ from typing import Any, TypeVar, Union, get_args from warnings import warn +from .utils import is_oci_uri, is_s3_uri + from .core import ModelRegistryAPIClient from .exceptions import StoreError from .types import ( @@ -203,9 +205,9 @@ def upload_artifact_and_register_model( metadata: Mapping[str, SupportedTypes] | None = None, upload_client_params: Mapping[str, str] | None = None, ) -> RegisteredModel: - if destination_uri.startswith("s3://"): + if is_s3_uri(destination_uri): self._upload_to_s3(artifact_local_path, destination_uri, upload_client_params['region_name']) - elif destination_uri.startswith("oci://"): + elif is_oci_uri(destination_uri): self._upload_to_oci(artifact_local_path, destination_uri) else: msg = "Invalid destination URI. Must start with 's3://' or 'oci://'" diff --git a/clients/python/src/model_registry/utils.py b/clients/python/src/model_registry/utils.py index e60dcf5d..866d5e85 100644 --- a/clients/python/src/model_registry/utils.py +++ b/clients/python/src/model_registry/utils.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +import re from typing_extensions import overload @@ -90,3 +91,46 @@ def s3_uri_from( # https://alexwlchan.net/2020/s3-keys-are-not-file-paths/ nor do they resolve to valid URls # FIXME: is this safe? return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}" + +s3_prefix = "s3://" + +def is_s3_uri(uri: str): + """Checks whether a string is a valid S3 URI + + This helper function checks whether the string starts with the correct s3 prefix (s3://) and + whether the string contains both a bucket and a key. + + Args: + uri: The URI to check + + Returns: + Boolean indicating whether it is a valid S3 URI + """ + if not uri.startswith(s3_prefix): + return False + # Slice the uri from prefix onward, then check if there are 2 components when splitting on "/" + path = uri[len(s3_prefix) :] + if len(path.split("/", 1)) != 2: + return False + return True + +oci_pattern = r'^oci://(?P[^/]+)/(?P[A-Za-z0-9_\-/]+)(:(?P[A-Za-z0-9_.-]+))?$' + +def is_oci_uri(uri: str): + """Checks whether a string is a valid OCI URI + + The expected format is: + oci:///[:] + + Examples of valid URIs: + oci://registry.example.com/my-namespace/my-repo:latest + oci://localhost:5000/my-repo + + Args: + uri: The URI to check + + Returns: + Boolean indicating whether it is a valid OCI URI + """ + return re.match(oci_pattern, uri) is not None + \ No newline at end of file