Skip to content

Commit

Permalink
Fix for linter issue
Browse files Browse the repository at this point in the history
  • Loading branch information
treff7es committed Dec 2, 2024
1 parent 1cfdcf1 commit 9225a6c
Showing 1 changed file with 25 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,27 @@ def strip_quotes(string: str) -> str:
model_metrics,
)

@staticmethod
def get_group_name_from_arn(arn: str) -> str:
"""
Extract model package group name from a SageMaker ARN.
Args:
arn (str): Full ARN of the model package group
Returns:
str: Name of the model package group
Example:
>>> arn = "arn:aws:sagemaker:eu-west-1:123456789:model-package-group/my-model-group"
>>> get_group_name_from_arn(arn)
"my-model-group"
"""
logger.info(
f"Extracting group name from ARN: {arn} because group was not seen before"
)
return arn.split("/")[-1]

def get_model_wu(
self,
model_details: "DescribeModelOutputTypeDef",
Expand Down Expand Up @@ -427,16 +448,12 @@ def get_model_wu(

model_group_arns = model_uri_groups | model_image_groups

# Filter, sort the model group names, and log missing keys in one shot
model_group_names = sorted(
[
self.group_arn_to_name[arn]
if arn in self.group_arn_to_name
else logger.warning(
f"Model is associated with a group ARN {arn} which was not listed in the model groups"
)
or arn
for arn in model_group_arns
self.group_arn_to_name[x]
if x in self.group_arn_to_name
else self.get_group_name_from_arn(x)
for x in model_group_arns
]
)

Expand Down

0 comments on commit 9225a6c

Please sign in to comment.