Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add image scheduler abstraction #115

Merged
merged 2 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tox.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ jobs:
steps:
- uses: actions/checkout@v4
with:
fetch-depth: '0'
lfs: 'true'
- uses: actions/setup-python@v5
with:
Expand Down
49 changes: 17 additions & 32 deletions src/aws/osml/model_runner/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.
# Copyright 2023-2025 Amazon.com, Inc. or its affiliates.

import logging

Expand All @@ -7,14 +7,15 @@
from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration
from aws.osml.model_runner.api import get_image_path

from .api import ImageRequest, InvalidImageRequestException, RegionRequest
from .api import ImageRequest, RegionRequest
from .app_config import ServiceConfig
from .common import EndpointUtils, ThreadingLocalContextFilter
from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable
from .exceptions import RetryableJobException, SelfThrottledRegionException
from .image_request_handler import ImageRequestHandler
from .queue import RequestQueue
from .region_request_handler import RegionRequestHandler
from .scheduler.fifo_image_scheduler import FIFOImageScheduler
from .status import ImageStatusMonitor, RegionStatusMonitor
from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy

Expand All @@ -41,9 +42,10 @@ def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrate
self.config = ServiceConfig()
self.tiling_strategy = tiling_strategy

# Set up queues and monitors
self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0)
self.image_requests_iter = iter(self.image_request_queue)
# Set up the job scheduler
self.image_job_scheduler = FIFOImageScheduler(RequestQueue(self.config.image_queue, wait_seconds=0))

# Set up internal queues and monitors
self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10)
self.region_requests_iter = iter(self.region_request_queue)

Expand Down Expand Up @@ -163,41 +165,24 @@ def _process_region_requests(self) -> bool:

def _process_image_requests(self) -> bool:
"""
Processes messages from the image request queue.

This method retrieves and processes image requests from the SQS queue. It validates
the image request, and if valid, passes it to the `ImageRequestHandler` for further
processing. In case of a retryable exception, the request is reset in the queue with
a visibility timeout. If the image request fails due to an error, it is marked as
failed and the appropriate actions are taken.
Processes messages from the image job scheduler.

:raises InvalidImageRequestException: If the image request is found to be invalid.
:raises Exception: If an unexpected error occurs during processing.

:return: True if a image request was processed, False if not.
:return: True if an image request was processed, False if not.
"""
logger.debug("Checking work queue for images to process...")
receipt_handle, image_request_message = next(self.image_requests_iter)
image_request = None
if image_request_message:
image_request = self.image_job_scheduler.get_next_scheduled_request()
if image_request:
try:
image_request = ImageRequest.from_external_message(image_request_message)
ThreadingLocalContextFilter.set_context(image_request.__dict__)

if not image_request.is_valid():
raise InvalidImageRequestException(f"Invalid image request: {image_request_message}")

logger.info(f"Starting processing for image request: {image_request.job_id}")
self.image_request_handler.process_image_request(image_request)
self.image_request_queue.finish_request(receipt_handle)
self.image_job_scheduler.finish_request(image_request)
except RetryableJobException:
self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0)
self.image_job_scheduler.finish_request(image_request, should_retry=True)
except Exception as err:
logger.error(f"Error processing image request: {err}")
if image_request:
self._fail_image_request(image_request, err)
self.image_request_queue.finish_request(receipt_handle)
finally:
return True
self._fail_image_request(image_request, err)
self.image_job_scheduler.finish_request(image_request)
return True
else:
return False

Expand Down
8 changes: 8 additions & 0 deletions src/aws/osml/model_runner/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright 2025 Amazon.com, Inc. or its affiliates.

# Telling flake8 to not flag errors in this file. It is normal that these classes are imported but not used in an
# __init__.py file.
# flake8: noqa

from .fifo_image_scheduler import FIFOImageScheduler
from .image_scheduler import ImageScheduler
62 changes: 62 additions & 0 deletions src/aws/osml/model_runner/scheduler/fifo_image_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2025 Amazon.com, Inc. or its affiliates.

import logging
from typing import Optional

from aws.osml.model_runner.api import ImageRequest, InvalidImageRequestException
from aws.osml.model_runner.queue import RequestQueue
from aws.osml.model_runner.scheduler.image_scheduler import ImageScheduler

logger = logging.getLogger(__name__)


class FIFOImageScheduler(ImageScheduler):
"""
This first in first out (FIFO) scheduler is just a pass through to a request queue.
"""

def __init__(self, image_request_queue: RequestQueue):
self.image_request_queue = image_request_queue
self.image_requests_iter = iter(self.image_request_queue)
self.job_id_to_message_handle = {}

def get_next_scheduled_request(self) -> Optional[ImageRequest]:
"""
Return the next image request to be processed. This implementation retrieves the next message on the queue
and returns the image request created from that message.

:return: the next image request, None if there is not a request pending execution
"""
logger.debug("FIFO image scheduler checking work queue for images to process...")
try:
receipt_handle, image_request_message = next(self.image_requests_iter)
if image_request_message:
try:
image_request = ImageRequest.from_external_message(image_request_message)
if not image_request.is_valid():
raise InvalidImageRequestException(f"Invalid image request: {image_request_message}")

self.job_id_to_message_handle[image_request.job_id] = receipt_handle
return image_request
except Exception:
logger.error("Failed to parse image request", exc_info=True)
self.image_request_queue.finish_request(receipt_handle)
except Exception:
logger.error("Unable to retrieve an image request from the queue", exc_info=True)

return None

def finish_request(self, image_request: ImageRequest, should_retry: bool = False) -> None:
"""
Mark the given image request as finished.

:param image_request: the image request
:param should_retry: true if this request was not complete and can be retried immediately
"""
logger.debug(f"Finished processing image request: {image_request}")
receipt_handle = self.job_id_to_message_handle[image_request.job_id]
if should_retry:
self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0)
else:
self.image_request_queue.finish_request(receipt_handle)
del self.job_id_to_message_handle[image_request.job_id]
29 changes: 29 additions & 0 deletions src/aws/osml/model_runner/scheduler/image_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2025 Amazon.com, Inc. or its affiliates.

from abc import ABC, abstractmethod
from typing import Optional

from aws.osml.model_runner.api import ImageRequest


class ImageScheduler(ABC):
"""
ImageSchedule defines an abstract base for classes that determine how to schedule images for processing.
"""

@abstractmethod
def get_next_scheduled_request(self) -> Optional[ImageRequest]:
"""
Return the next image request to be processed.

:return: the image reqeust
"""

@abstractmethod
def finish_request(self, image_request: ImageRequest, should_retry: bool = False) -> None:
"""
Mark the given image request as finished.

:param image_request: the image request
:param should_retry: true if this request was not complete and can be retried immediately
"""
184 changes: 184 additions & 0 deletions test/aws/osml/model_runner/scheduler/test_fifo_image_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright 2025 Amazon.com, Inc. or its affiliates.

import unittest

from aws.osml.model_runner.api.image_request import ImageRequest
from aws.osml.model_runner.api.inference import ModelInvokeMode
from aws.osml.model_runner.scheduler.fifo_image_scheduler import FIFOImageScheduler


class MockRequestQueue:
def __init__(self):
self.messages = []
self.finished_receipts = set()
self.reset_receipts = set()

def add_message(self, receipt_handle, message):
self.messages.append((receipt_handle, message))

def finish_request(self, receipt_handle):
self.finished_receipts.add(receipt_handle)

def reset_request(self, receipt_handle, visibility_timeout=0):
self.reset_receipts.add(receipt_handle)

def __iter__(self):
return iter(self.messages)


class TestFIFOImageScheduler(unittest.TestCase):
def setUp(self):
self.mock_queue = MockRequestQueue()
self.scheduler = FIFOImageScheduler(self.mock_queue)

def test_get_next_scheduled_request_success(self):
"""Test successful retrieval of next scheduled request"""
# Setup
test_receipt_handle = "receipt-123"
test_message = {
"jobName": "test-job-name",
"jobId": "job-123",
"imageUrls": ["test-image-url"],
"outputs": [
{"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"},
{"type": "Kinesis", "stream": "test-stream", "batchSize": 1000},
],
"imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"},
"imageProcessorTileSize": 1024,
"imageProcessorTileOverlap": 50,
}
self.mock_queue.add_message(test_receipt_handle, test_message)

# Execute
result = self.scheduler.get_next_scheduled_request()

# Assert
self.assertIsInstance(result, ImageRequest)
self.assertTrue(result.is_valid())
self.assertEqual(result.job_id, "job-123")

def test_get_next_scheduled_request_empty_queue(self):
"""Test behavior when queue is empty"""
# Execute
result = self.scheduler.get_next_scheduled_request()

# Assert
self.assertIsNone(result)

def test_get_next_scheduled_request_invalid_request(self):
"""Test handling of invalid image request"""
# Setup
test_receipt_handle = "receipt-123"
test_message = {
"jobId": "job-123",
# Missing required fields to make it invalid
}
self.mock_queue.add_message(test_receipt_handle, test_message)

# Execute
result = self.scheduler.get_next_scheduled_request()

# Assert
self.assertIsNone(result)
self.assertIn(test_receipt_handle, self.mock_queue.finished_receipts)

def test_finish_request_success(self):
"""Test successful completion of request"""
# Setup
test_receipt_handle = "receipt-123"
test_image_request = ImageRequest(
job_id="job-123",
image_id="image-123",
image_url="s3://bucket/image.tif",
image_read_role="arn:aws:iam::123456789012:role/read-role",
outputs=[{"sink_type": "s3", "url": "s3://bucket/output/"}],
model_name="test-model",
model_invoke_mode=ModelInvokeMode.SM_ENDPOINT,
model_invocation_role="arn:aws:iam::123456789012:role/invoke-role",
)
self.scheduler.job_id_to_message_handle["job-123"] = test_receipt_handle

# Execute
self.scheduler.finish_request(test_image_request)

# Assert
self.assertIn(test_receipt_handle, self.mock_queue.finished_receipts)
self.assertNotIn("job-123", self.scheduler.job_id_to_message_handle)

def test_finish_request_with_retry(self):
"""Test finishing request with retry flag"""
# Setup
test_receipt_handle = "receipt-123"
test_image_request = ImageRequest(
job_id="job-123",
image_id="image-123",
image_url="s3://bucket/image.tif",
image_read_role="arn:aws:iam::123456789012:role/read-role",
outputs=[{"sink_type": "s3", "url": "s3://bucket/output/"}],
model_name="test-model",
model_invoke_mode=ModelInvokeMode.SM_ENDPOINT,
model_invocation_role="arn:aws:iam::123456789012:role/invoke-role",
)
self.scheduler.job_id_to_message_handle["job-123"] = test_receipt_handle

# Execute
self.scheduler.finish_request(test_image_request, should_retry=True)

# Assert
self.assertIn(test_receipt_handle, self.mock_queue.reset_receipts)
self.assertNotIn("job-123", self.scheduler.job_id_to_message_handle)

def test_multiple_requests_in_queue(self):
"""Test handling multiple requests in the queue"""
# Setup
test_messages = [
(
"receipt-1",
{
"jobName": "test-job-name",
"jobId": "job-1",
"imageUrls": ["test-image-url"],
"outputs": [
{"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"},
{"type": "Kinesis", "stream": "test-stream", "batchSize": 1000},
],
"imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"},
"imageProcessorTileSize": 1024,
"imageProcessorTileOverlap": 50,
},
),
(
"receipt-2",
{
"jobName": "test-job-name",
"jobId": "job-2",
"imageUrls": ["test-image-url"],
"outputs": [
{"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"},
{"type": "Kinesis", "stream": "test-stream", "batchSize": 1000},
],
"imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"},
"imageProcessorTileSize": 1024,
"imageProcessorTileOverlap": 50,
},
),
]

for receipt, message in test_messages:
self.mock_queue.add_message(receipt, message)

# Execute and Assert first request
first_request = self.scheduler.get_next_scheduled_request()
self.assertIsInstance(first_request, ImageRequest)
self.assertEqual(first_request.job_id, "job-1")
self.assertEqual(self.scheduler.job_id_to_message_handle["job-1"], "receipt-1")

# Execute and Assert second request
second_request = self.scheduler.get_next_scheduled_request()
self.assertIsInstance(second_request, ImageRequest)
self.assertEqual(second_request.job_id, "job-2")
self.assertEqual(self.scheduler.job_id_to_message_handle["job-2"], "receipt-2")


if __name__ == "__main__":
unittest.main()
Loading
Loading