diff --git a/.github/workflows/tox.yml b/.github/workflows/tox.yml index 3b0f577f..54c3627f 100644 --- a/.github/workflows/tox.yml +++ b/.github/workflows/tox.yml @@ -16,6 +16,7 @@ jobs: steps: - uses: actions/checkout@v4 with: + fetch-depth: '0' lfs: 'true' - uses: actions/setup-python@v5 with: diff --git a/src/aws/osml/model_runner/model_runner.py b/src/aws/osml/model_runner/model_runner.py index caaa952d..160e79ad 100644 --- a/src/aws/osml/model_runner/model_runner.py +++ b/src/aws/osml/model_runner/model_runner.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. +# Copyright 2023-2025 Amazon.com, Inc. or its affiliates. import logging @@ -7,7 +7,7 @@ from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration from aws.osml.model_runner.api import get_image_path -from .api import ImageRequest, InvalidImageRequestException, RegionRequest +from .api import ImageRequest, RegionRequest from .app_config import ServiceConfig from .common import EndpointUtils, ThreadingLocalContextFilter from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable @@ -15,6 +15,7 @@ from .image_request_handler import ImageRequestHandler from .queue import RequestQueue from .region_request_handler import RegionRequestHandler +from .scheduler.fifo_image_scheduler import FIFOImageScheduler from .status import ImageStatusMonitor, RegionStatusMonitor from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy @@ -41,9 +42,10 @@ def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrate self.config = ServiceConfig() self.tiling_strategy = tiling_strategy - # Set up queues and monitors - self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0) - self.image_requests_iter = iter(self.image_request_queue) + # Set up the job scheduler + self.image_job_scheduler = FIFOImageScheduler(RequestQueue(self.config.image_queue, wait_seconds=0)) + + # Set up internal 
queues and monitors self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10) self.region_requests_iter = iter(self.region_request_queue) @@ -163,41 +165,24 @@ def _process_region_requests(self) -> bool: def _process_image_requests(self) -> bool: """ - Processes messages from the image request queue. - - This method retrieves and processes image requests from the SQS queue. It validates - the image request, and if valid, passes it to the `ImageRequestHandler` for further - processing. In case of a retryable exception, the request is reset in the queue with - a visibility timeout. If the image request fails due to an error, it is marked as - failed and the appropriate actions are taken. + Processes messages from the image job scheduler. - :raises InvalidImageRequestException: If the image request is found to be invalid. - :raises Exception: If an unexpected error occurs during processing. - - :return: True if a image request was processed, False if not. + :return: True if an image request was processed, False if not. 
""" - logger.debug("Checking work queue for images to process...") - receipt_handle, image_request_message = next(self.image_requests_iter) - image_request = None - if image_request_message: + image_request = self.image_job_scheduler.get_next_scheduled_request() + if image_request: try: - image_request = ImageRequest.from_external_message(image_request_message) ThreadingLocalContextFilter.set_context(image_request.__dict__) - - if not image_request.is_valid(): - raise InvalidImageRequestException(f"Invalid image request: {image_request_message}") - + logger.info(f"Starting processing for image request: {image_request.job_id}") self.image_request_handler.process_image_request(image_request) - self.image_request_queue.finish_request(receipt_handle) + self.image_job_scheduler.finish_request(image_request) except RetryableJobException: - self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0) + self.image_job_scheduler.finish_request(image_request, should_retry=True) except Exception as err: logger.error(f"Error processing image request: {err}") - if image_request: - self._fail_image_request(image_request, err) - self.image_request_queue.finish_request(receipt_handle) - finally: - return True + self._fail_image_request(image_request, err) + self.image_job_scheduler.finish_request(image_request) + return True else: return False diff --git a/src/aws/osml/model_runner/scheduler/__init__.py b/src/aws/osml/model_runner/scheduler/__init__.py new file mode 100644 index 00000000..ae69bbca --- /dev/null +++ b/src/aws/osml/model_runner/scheduler/__init__.py @@ -0,0 +1,8 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. + +# Telling flake8 to not flag errors in this file. It is normal that these classes are imported but not used in an +# __init__.py file. 
+# flake8: noqa + +from .fifo_image_scheduler import FIFOImageScheduler +from .image_scheduler import ImageScheduler diff --git a/src/aws/osml/model_runner/scheduler/fifo_image_scheduler.py b/src/aws/osml/model_runner/scheduler/fifo_image_scheduler.py new file mode 100644 index 00000000..850d247b --- /dev/null +++ b/src/aws/osml/model_runner/scheduler/fifo_image_scheduler.py @@ -0,0 +1,62 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. + +import logging +from typing import Optional + +from aws.osml.model_runner.api import ImageRequest, InvalidImageRequestException +from aws.osml.model_runner.queue import RequestQueue +from aws.osml.model_runner.scheduler.image_scheduler import ImageScheduler + +logger = logging.getLogger(__name__) + + +class FIFOImageScheduler(ImageScheduler): + """ + This first in first out (FIFO) scheduler is just a pass through to a request queue. + """ + + def __init__(self, image_request_queue: RequestQueue): + self.image_request_queue = image_request_queue + self.image_requests_iter = iter(self.image_request_queue) + self.job_id_to_message_handle = {} + + def get_next_scheduled_request(self) -> Optional[ImageRequest]: + """ + Return the next image request to be processed. This implementation retrieves the next message on the queue + and returns the image request created from that message. 
+ + :return: the next image request, None if there is not a request pending execution + """ + logger.debug("FIFO image scheduler checking work queue for images to process...") + try: + receipt_handle, image_request_message = next(self.image_requests_iter) + if image_request_message: + try: + image_request = ImageRequest.from_external_message(image_request_message) + if not image_request.is_valid(): + raise InvalidImageRequestException(f"Invalid image request: {image_request_message}") + + self.job_id_to_message_handle[image_request.job_id] = receipt_handle + return image_request + except Exception: + logger.error("Failed to parse image request", exc_info=True) + self.image_request_queue.finish_request(receipt_handle) + except Exception: + logger.error("Unable to retrieve an image request from the queue", exc_info=True) + + return None + + def finish_request(self, image_request: ImageRequest, should_retry: bool = False) -> None: + """ + Mark the given image request as finished. + + :param image_request: the image request + :param should_retry: true if this request was not complete and can be retried immediately + """ + logger.debug(f"Finished processing image request: {image_request}") + receipt_handle = self.job_id_to_message_handle[image_request.job_id] + if should_retry: + self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0) + else: + self.image_request_queue.finish_request(receipt_handle) + del self.job_id_to_message_handle[image_request.job_id] diff --git a/src/aws/osml/model_runner/scheduler/image_scheduler.py b/src/aws/osml/model_runner/scheduler/image_scheduler.py new file mode 100644 index 00000000..cf70c795 --- /dev/null +++ b/src/aws/osml/model_runner/scheduler/image_scheduler.py @@ -0,0 +1,29 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. 
+ +from abc import ABC, abstractmethod +from typing import Optional + +from aws.osml.model_runner.api import ImageRequest + + +class ImageScheduler(ABC): + """ + ImageScheduler defines an abstract base for classes that determine how to schedule images for processing. + """ + + @abstractmethod + def get_next_scheduled_request(self) -> Optional[ImageRequest]: + """ + Return the next image request to be processed. + + :return: the next image request, None if there is not a request pending execution + """ + + @abstractmethod + def finish_request(self, image_request: ImageRequest, should_retry: bool = False) -> None: + """ + Mark the given image request as finished. + + :param image_request: the image request + :param should_retry: true if this request was not complete and can be retried immediately + """ diff --git a/test/aws/osml/model_runner/scheduler/test_fifo_image_scheduler.py b/test/aws/osml/model_runner/scheduler/test_fifo_image_scheduler.py new file mode 100644 index 00000000..ac129a06 --- /dev/null +++ b/test/aws/osml/model_runner/scheduler/test_fifo_image_scheduler.py @@ -0,0 +1,184 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates.
+ +import unittest + +from aws.osml.model_runner.api.image_request import ImageRequest +from aws.osml.model_runner.api.inference import ModelInvokeMode +from aws.osml.model_runner.scheduler.fifo_image_scheduler import FIFOImageScheduler + + +class MockRequestQueue: + def __init__(self): + self.messages = [] + self.finished_receipts = set() + self.reset_receipts = set() + + def add_message(self, receipt_handle, message): + self.messages.append((receipt_handle, message)) + + def finish_request(self, receipt_handle): + self.finished_receipts.add(receipt_handle) + + def reset_request(self, receipt_handle, visibility_timeout=0): + self.reset_receipts.add(receipt_handle) + + def __iter__(self): + return iter(self.messages) + + +class TestFIFOImageScheduler(unittest.TestCase): + def setUp(self): + self.mock_queue = MockRequestQueue() + self.scheduler = FIFOImageScheduler(self.mock_queue) + + def test_get_next_scheduled_request_success(self): + """Test successful retrieval of next scheduled request""" + # Setup + test_receipt_handle = "receipt-123" + test_message = { + "jobName": "test-job-name", + "jobId": "job-123", + "imageUrls": ["test-image-url"], + "outputs": [ + {"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"}, + {"type": "Kinesis", "stream": "test-stream", "batchSize": 1000}, + ], + "imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"}, + "imageProcessorTileSize": 1024, + "imageProcessorTileOverlap": 50, + } + self.mock_queue.add_message(test_receipt_handle, test_message) + + # Execute + result = self.scheduler.get_next_scheduled_request() + + # Assert + self.assertIsInstance(result, ImageRequest) + self.assertTrue(result.is_valid()) + self.assertEqual(result.job_id, "job-123") + + def test_get_next_scheduled_request_empty_queue(self): + """Test behavior when queue is empty""" + # Execute + result = self.scheduler.get_next_scheduled_request() + + # Assert + self.assertIsNone(result) + + def 
test_get_next_scheduled_request_invalid_request(self): + """Test handling of invalid image request""" + # Setup + test_receipt_handle = "receipt-123" + test_message = { + "jobId": "job-123", + # Missing required fields to make it invalid + } + self.mock_queue.add_message(test_receipt_handle, test_message) + + # Execute + result = self.scheduler.get_next_scheduled_request() + + # Assert + self.assertIsNone(result) + self.assertIn(test_receipt_handle, self.mock_queue.finished_receipts) + + def test_finish_request_success(self): + """Test successful completion of request""" + # Setup + test_receipt_handle = "receipt-123" + test_image_request = ImageRequest( + job_id="job-123", + image_id="image-123", + image_url="s3://bucket/image.tif", + image_read_role="arn:aws:iam::123456789012:role/read-role", + outputs=[{"sink_type": "s3", "url": "s3://bucket/output/"}], + model_name="test-model", + model_invoke_mode=ModelInvokeMode.SM_ENDPOINT, + model_invocation_role="arn:aws:iam::123456789012:role/invoke-role", + ) + self.scheduler.job_id_to_message_handle["job-123"] = test_receipt_handle + + # Execute + self.scheduler.finish_request(test_image_request) + + # Assert + self.assertIn(test_receipt_handle, self.mock_queue.finished_receipts) + self.assertNotIn("job-123", self.scheduler.job_id_to_message_handle) + + def test_finish_request_with_retry(self): + """Test finishing request with retry flag""" + # Setup + test_receipt_handle = "receipt-123" + test_image_request = ImageRequest( + job_id="job-123", + image_id="image-123", + image_url="s3://bucket/image.tif", + image_read_role="arn:aws:iam::123456789012:role/read-role", + outputs=[{"sink_type": "s3", "url": "s3://bucket/output/"}], + model_name="test-model", + model_invoke_mode=ModelInvokeMode.SM_ENDPOINT, + model_invocation_role="arn:aws:iam::123456789012:role/invoke-role", + ) + self.scheduler.job_id_to_message_handle["job-123"] = test_receipt_handle + + # Execute + self.scheduler.finish_request(test_image_request, 
should_retry=True) + + # Assert + self.assertIn(test_receipt_handle, self.mock_queue.reset_receipts) + self.assertNotIn("job-123", self.scheduler.job_id_to_message_handle) + + def test_multiple_requests_in_queue(self): + """Test handling multiple requests in the queue""" + # Setup + test_messages = [ + ( + "receipt-1", + { + "jobName": "test-job-name", + "jobId": "job-1", + "imageUrls": ["test-image-url"], + "outputs": [ + {"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"}, + {"type": "Kinesis", "stream": "test-stream", "batchSize": 1000}, + ], + "imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"}, + "imageProcessorTileSize": 1024, + "imageProcessorTileOverlap": 50, + }, + ), + ( + "receipt-2", + { + "jobName": "test-job-name", + "jobId": "job-2", + "imageUrls": ["test-image-url"], + "outputs": [ + {"type": "S3", "bucket": "test-bucket", "prefix": "test-bucket-prefix"}, + {"type": "Kinesis", "stream": "test-stream", "batchSize": 1000}, + ], + "imageProcessor": {"name": "test-model", "type": "SM_ENDPOINT"}, + "imageProcessorTileSize": 1024, + "imageProcessorTileOverlap": 50, + }, + ), + ] + + for receipt, message in test_messages: + self.mock_queue.add_message(receipt, message) + + # Execute and Assert first request + first_request = self.scheduler.get_next_scheduled_request() + self.assertIsInstance(first_request, ImageRequest) + self.assertEqual(first_request.job_id, "job-1") + self.assertEqual(self.scheduler.job_id_to_message_handle["job-1"], "receipt-1") + + # Execute and Assert second request + second_request = self.scheduler.get_next_scheduled_request() + self.assertIsInstance(second_request, ImageRequest) + self.assertEqual(second_request.job_id, "job-2") + self.assertEqual(self.scheduler.job_id_to_message_handle["job-2"], "receipt-2") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/aws/osml/model_runner/test_model_runner.py b/test/aws/osml/model_runner/test_model_runner.py index 45705369..357a5458 100644 
--- a/test/aws/osml/model_runner/test_model_runner.py +++ b/test/aws/osml/model_runner/test_model_runner.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. +# Copyright 2023-2025 Amazon.com, Inc. or its affiliates. import unittest from unittest.mock import MagicMock, patch @@ -56,54 +56,43 @@ def test_process_region_requests_success(self, mock_load_gdal, mock_finish_reque self.runner.image_request_handler.complete_image_request.assert_called_once() mock_finish_request.assert_called_once_with("receipt_handle") - @patch("aws.osml.model_runner.model_runner.ImageRequest") - @patch("aws.osml.model_runner.model_runner.RequestQueue.finish_request") - def test_process_image_requests_invalid(self, mock_finish_request, mock_image_request): - """Test that invalid image requests raise an InvalidImageRequestException.""" - # Mock invalid image request - mock_image_request_message = MagicMock() - mock_image_request_instance = MagicMock(is_valid=MagicMock(return_value=False)) - mock_image_request.from_external_message.return_value = mock_image_request_instance - self.runner.image_requests_iter = iter([("receipt_handle", mock_image_request_message)]) - - self.runner._process_image_requests() + def test_process_image_request_noimage(self): + """Test path where the scheduler does not return an ImageRequest to process""" + with patch.object(self.runner, "image_job_scheduler", new_callable=MagicMock) as mock_scheduler: + mock_scheduler.get_next_scheduled_request.return_value = None + result = self.runner._process_image_requests() + self.assertFalse(result) - # Ensure request was marked as completed - mock_finish_request.assert_called_once_with("receipt_handle") - - @patch("aws.osml.model_runner.model_runner.ImageRequest") - @patch("aws.osml.model_runner.model_runner.RequestQueue.reset_request") - def test_process_image_requests_retryable(self, mock_reset_request, mock_image_request): + def test_process_image_requests_retryable(self): """Test that a 
RetryableJobException resets the request.""" - # Mock retryable job exception - mock_image_request_message = MagicMock() - mock_image_request_instance = MagicMock(is_valid=MagicMock(return_value=True)) - mock_image_request.from_external_message.return_value = mock_image_request_instance - self.runner.image_requests_iter = iter([("receipt_handle", mock_image_request_message)]) - self.runner.image_request_handler.process_image_request.side_effect = RetryableJobException() - # Call method - self.runner._process_image_requests() + with patch.object(self.runner, "image_job_scheduler", new_callable=MagicMock) as mock_scheduler, patch.object( + self.runner, "image_request_handler", new_callable=MagicMock + ) as mock_handler: + mock_image_request = MagicMock() + mock_scheduler.get_next_scheduled_request.return_value = mock_image_request + mock_handler.process_image_request.side_effect = RetryableJobException() + + result = self.runner._process_image_requests() + self.assertTrue(result) - # Ensure request was reset - mock_reset_request.assert_called_once_with("receipt_handle", visibility_timeout=0) + mock_scheduler.finish_request.assert_called_once_with(mock_image_request, should_retry=True) @patch("aws.osml.model_runner.model_runner.ImageRequest") def test_process_image_requests_general_error(self, mock_image_request): """Test that general exceptions mark the image request as failed.""" - # Mock exception - mock_image_request_message = MagicMock() - mock_image_request_instance = MagicMock() - self.runner._fail_image_request = MagicMock() - mock_image_request.from_external_message.return_value = mock_image_request_instance - self.runner.image_requests_iter = iter([("receipt_handle", mock_image_request_message)]) - self.runner.image_request_handler.process_image_request.side_effect = Exception("Some error") - # Call method - self.runner._process_image_requests() + with patch.object(self.runner, "image_job_scheduler", new_callable=MagicMock) as mock_scheduler, patch.object( + 
self.runner, "image_request_handler", new_callable=MagicMock + ) as mock_handler: + mock_image_request = MagicMock() + mock_scheduler.get_next_scheduled_request.return_value = mock_image_request + mock_handler.process_image_request.side_effect = Exception("Some error") + + result = self.runner._process_image_requests() + self.assertTrue(result) - # Ensure image request was failed - self.runner._fail_image_request.assert_called() + mock_scheduler.finish_request.assert_called_once_with(mock_image_request) @patch("aws.osml.model_runner.model_runner.RegionRequestHandler.process_region_request") @patch("aws.osml.model_runner.model_runner.RequestQueue.finish_request") diff --git a/tox.ini b/tox.ini index 30aeb47e..e221655a 100755 --- a/tox.ini +++ b/tox.ini @@ -70,7 +70,7 @@ commands = skip_install = true conda_env = deps = pre-commit -commands = pre-commit run --all-files --show-diff-on-failure +commands = pre-commit run --from-ref origin/main --to-ref HEAD [testenv:docs] changedir = doc