Skip to content

Commit

Permalink
chore: add oci and s3 helper methods
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Dobroveanu <edobrove@redhat.com>
  • Loading branch information
Crazyglue committed Feb 5, 2025
1 parent 8090a03 commit 409964d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
6 changes: 4 additions & 2 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Any, TypeVar, Union, get_args
from warnings import warn

from .utils import is_oci_uri, is_s3_uri

from .core import ModelRegistryAPIClient
from .exceptions import StoreError
from .types import (
Expand Down Expand Up @@ -203,9 +205,9 @@ def upload_artifact_and_register_model(
metadata: Mapping[str, SupportedTypes] | None = None,
upload_client_params: Mapping[str, str] | None = None,
) -> RegisteredModel:
if destination_uri.startswith("s3://"):
if is_s3_uri(destination_uri):
self._upload_to_s3(artifact_local_path, destination_uri, upload_client_params['region_name'])
elif destination_uri.startswith("oci://"):
elif is_oci_uri(destination_uri):
self._upload_to_oci(artifact_local_path, destination_uri)
else:
msg = "Invalid destination URI. Must start with 's3://' or 'oci://'"
Expand Down
44 changes: 44 additions & 0 deletions clients/python/src/model_registry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import os
import re

from typing_extensions import overload

Expand Down Expand Up @@ -90,3 +91,46 @@ def s3_uri_from(
# https://alexwlchan.net/2020/s3-keys-are-not-file-paths/ nor do they resolve to valid URls
# FIXME: is this safe?
return f"s3://{bucket}/{path}?endpoint={endpoint}&defaultRegion={region}"

s3_prefix = "s3://"

def is_s3_uri(uri: str):
"""Checks whether a string is a valid S3 URI
This helper function checks whether the string starts with the correct s3 prefix (s3://) and
whether the string contains both a bucket and a key.
Args:
uri: The URI to check
Returns:
Boolean indicating whether it is a valid S3 URI
"""
if not uri.startswith(s3_prefix):
return False
# Slice the uri from prefix onward, then check if there are 2 components when splitting on "/"
path = uri[len(s3_prefix) :]
if len(path.split("/", 1)) != 2:
return False
return True

oci_pattern = r'^oci://(?P<host>[^/]+)/(?P<repository>[A-Za-z0-9_\-/]+)(:(?P<tag>[A-Za-z0-9_.-]+))?$'

def is_oci_uri(uri: str):
"""Checks whether a string is a valid OCI URI
The expected format is:
oci://<host>/<repository>[:<tag>]
Examples of valid URIs:
oci://registry.example.com/my-namespace/my-repo:latest
oci://localhost:5000/my-repo
Args:
uri: The URI to check
Returns:
Boolean indicating whether it is a valid OCI URI
"""
return re.match(oci_pattern, uri) is not None

0 comments on commit 409964d

Please sign in to comment.