From e667f4b76a75e0791e5eda220256810573cc3ee8 Mon Sep 17 00:00:00 2001 From: drduhe Date: Tue, 15 Oct 2024 13:24:56 -0600 Subject: [PATCH] refactor: rename app.py to model_runner.py --- bin/oversightml-mr-entry-point.py | 2 +- src/aws/osml/model_runner/__init__.py | 2 + src/aws/osml/model_runner/api/__init__.py | 2 +- .../osml/model_runner/api/image_request.py | 38 --- .../osml/model_runner/api/request_utils.py | 59 ++++- src/aws/osml/model_runner/app.py | 210 --------------- .../model_runner/image_request_handler.py | 38 +-- src/aws/osml/model_runner/model_runner.py | 242 ++++++++++++++++++ .../test_image_request_handler.py | 1 + .../osml/model_runner/test_model_runner.py | 128 +++++++++ test/test_api.py | 4 +- test/{test_app.py => test_end_to_end.py} | 4 +- 12 files changed, 451 insertions(+), 279 deletions(-) delete mode 100644 src/aws/osml/model_runner/app.py create mode 100644 src/aws/osml/model_runner/model_runner.py create mode 100644 test/aws/osml/model_runner/test_model_runner.py rename test/{test_app.py => test_end_to_end.py} (99%) diff --git a/bin/oversightml-mr-entry-point.py b/bin/oversightml-mr-entry-point.py index 3f924c7f..dfd10195 100644 --- a/bin/oversightml-mr-entry-point.py +++ b/bin/oversightml-mr-entry-point.py @@ -11,7 +11,7 @@ from codeguru_profiler_agent import Profiler from pythonjsonlogger import jsonlogger -from aws.osml.model_runner.app import ModelRunner +from aws.osml.model_runner import ModelRunner from aws.osml.model_runner.common import ThreadingLocalContextFilter diff --git a/src/aws/osml/model_runner/__init__.py b/src/aws/osml/model_runner/__init__.py index 588610ec..bb365f14 100755 --- a/src/aws/osml/model_runner/__init__.py +++ b/src/aws/osml/model_runner/__init__.py @@ -3,3 +3,5 @@ # Telling flake8 to not flag errors in this file. It is normal that these classes are imported but not used in an # __init__.py file. # flake8: noqa + +from .model_runner import ModelRunner diff --git a/src/aws/osml/model_runner/api/__init__.py b/src/aws/osml/model_runner/api/__init__.py index 2f871f87..091b977e 100755 --- a/src/aws/osml/model_runner/api/__init__.py +++ b/src/aws/osml/model_runner/api/__init__.py @@ -8,5 +8,5 @@ from .image_request import ImageRequest from .inference import VALID_MODEL_HOSTING_OPTIONS, ModelInvokeMode from .region_request import RegionRequest -from .request_utils import shared_properties_are_valid +from .request_utils import get_image_path, shared_properties_are_valid from .sink import SinkMode, SinkType diff --git a/src/aws/osml/model_runner/api/image_request.py b/src/aws/osml/model_runner/api/image_request.py index ac8d056f..ec3a1f2b 100755 --- a/src/aws/osml/model_runner/api/image_request.py +++ b/src/aws/osml/model_runner/api/image_request.py @@ -4,12 +4,10 @@ from json import dumps, loads from typing import Any, Dict, List, Optional -import boto3 import shapely.geometry import shapely.wkt from shapely.geometry.base import BaseGeometry -from aws.osml.model_runner.app_config import BotoConfig from aws.osml.model_runner.common import ( FeatureDistillationAlgorithm, FeatureDistillationNMS, @@ -19,10 +17,8 @@ MRPostProcessing, MRPostprocessingStep, deserialize_post_processing_list, - get_credentials_for_assumed_role, ) -from .exceptions import InvalidS3ObjectException from .inference import ModelInvokeMode from .request_utils import shared_properties_are_valid from .sink import SinkType @@ -193,37 +189,3 @@ def get_feature_distillation_option(self) -> List[FeatureDistillationAlgorithm]: if op.step == MRPostprocessingStep.FEATURE_DISTILLATION and isinstance(op.algorithm, FeatureDistillationAlgorithm) ] - - @staticmethod - def validate_image_path(image_url: str, assumed_role: str) -> bool: - """ - Validate if an image exists in S3 bucket - - :param image_url: str = formatted image path to S3 bucket - :param assumed_role: str = containing a formatted arn role - - :return: bool - """ - bucket, key = image_url.replace("s3://", "").split("/", 1) - if assumed_role: - assumed_credentials = get_credentials_for_assumed_role(assumed_role) - # Here we will be writing to S3 using an IAM role other than the one for this process. - s3_client = boto3.client( - "s3", - aws_access_key_id=assumed_credentials["AccessKeyId"], - aws_secret_access_key=assumed_credentials["SecretAccessKey"], - aws_session_token=assumed_credentials["SessionToken"], - config=BotoConfig.default, - ) - else: - # If no invocation role is provided the assumption is that the default role for this - # container will be sufficient to read/write to the S3 bucket. - s3_client = boto3.client("s3", config=BotoConfig.default) - - try: - # head_object is the fastest approach to determine if it exists in S3 - # also its less expensive to do the head_object approach - s3_client.head_object(Bucket=bucket, Key=key) - return True - except Exception as err: - raise InvalidS3ObjectException("This image does not exist!") from err diff --git a/src/aws/osml/model_runner/api/request_utils.py b/src/aws/osml/model_runner/api/request_utils.py index fa13b5b3..b95393a7 100755 --- a/src/aws/osml/model_runner/api/request_utils.py +++ b/src/aws/osml/model_runner/api/request_utils.py @@ -1,7 +1,11 @@ # Copyright 2023-2024 Amazon.com, Inc. or its affiliates. -from aws.osml.model_runner.common import VALID_IMAGE_COMPRESSION, VALID_IMAGE_FORMATS +import boto3 +from aws.osml.model_runner.app_config import BotoConfig +from aws.osml.model_runner.common import VALID_IMAGE_COMPRESSION, VALID_IMAGE_FORMATS, get_credentials_for_assumed_role + +from .exceptions import InvalidS3ObjectException from .inference import VALID_MODEL_HOSTING_OPTIONS @@ -55,3 +59,56 @@ def shared_properties_are_valid(request) -> bool: return False return True + + +def get_image_path(image_url: str, assumed_role: str) -> str: + """ + Returns the formatted image path for GDAL to read the image, either from S3 or a local file. + + If the image URL points to an S3 path, this method validates the image's existence in S3 + and reformats the path to use GDAL's /vsis3/ driver. Otherwise, it returns the local or + network image path. + + :param image_url: str = formatted image path to S3 bucket + :param assumed_role: str = containing a formatted arn role + + :return: The formatted image path. + """ + if "s3:/" in image_url: + validate_image_path(image_url, assumed_role) + return image_url.replace("s3:/", "/vsis3", 1) + return image_url + + +def validate_image_path(image_url: str, assumed_role: str) -> bool: + """ + Validate if an image exists in S3 bucket + + :param image_url: str = formatted image path to S3 bucket + :param assumed_role: str = containing a formatted arn role + + :return: bool + """ + bucket, key = image_url.replace("s3://", "").split("/", 1) + if assumed_role: + assumed_credentials = get_credentials_for_assumed_role(assumed_role) + # Here we will be writing to S3 using an IAM role other than the one for this process. + s3_client = boto3.client( + "s3", + aws_access_key_id=assumed_credentials["AccessKeyId"], + aws_secret_access_key=assumed_credentials["SecretAccessKey"], + aws_session_token=assumed_credentials["SessionToken"], + config=BotoConfig.default, + ) + else: + # If no invocation role is provided the assumption is that the default role for this + # container will be sufficient to read/write to the S3 bucket. + s3_client = boto3.client("s3", config=BotoConfig.default) + + try: + # head_object is the fastest approach to determine if it exists in S3 + # also its less expensive to do the head_object approach + s3_client.head_object(Bucket=bucket, Key=key) + return True + except Exception as err: + raise InvalidS3ObjectException("This image does not exist!") from err diff --git a/src/aws/osml/model_runner/app.py b/src/aws/osml/model_runner/app.py deleted file mode 100644 index 08ad3f01..00000000 --- a/src/aws/osml/model_runner/app.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. - -import logging - -from osgeo import gdal - -from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration - -from .api import ImageRequest, InvalidImageRequestException, RegionRequest -from .app_config import ServiceConfig -from .common import EndpointUtils, ThreadingLocalContextFilter -from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable -from .exceptions import RetryableJobException, SelfThrottledRegionException -from .image_request_handler import ImageRequestHandler -from .queue import RequestQueue -from .region_request_handler import RegionRequestHandler -from .status import ImageStatusMonitor, RegionStatusMonitor -from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy - -# Set up logging configuration -logger = logging.getLogger(__name__) - -# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume -# no exceptions so we call this explicitly until the software can be updated to match. -gdal.UseExceptions() - - -class ModelRunner: - """ - Main class for operating the ModelRunner application. It monitors input queues for processing requests, - decomposes the image into a set of smaller regions and tiles, invokes an ML model endpoint with each tile, and - finally aggregates all the results into a single output which can be deposited into the desired output sinks. - """ - - def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None: - """ - Initialize a model runner with the injectable behaviors. - - :param tiling_strategy: class defining how a larger image will be broken into chunks for processing - """ - self.config = ServiceConfig() - self.tiling_strategy = tiling_strategy - self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0) - self.image_requests_iter = iter(self.image_request_queue) - self.job_table = JobTable(self.config.job_table) - self.region_request_table = RegionRequestTable(self.config.region_request_table) - self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table) - self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10) - self.region_requests_iter = iter(self.region_request_queue) - self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic) - self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic) - self.endpoint_utils = EndpointUtils() - # Pass dependencies into RegionRequestHandler - self.region_request_handler = RegionRequestHandler( - region_request_table=self.region_request_table, - job_table=self.job_table, - region_status_monitor=self.region_status_monitor, - endpoint_statistics_table=self.endpoint_statistics_table, - tiling_strategy=self.tiling_strategy, - endpoint_utils=self.endpoint_utils, - config=self.config, - ) - # Pass dependencies into ImageRequestHandler - self.image_request_handler = ImageRequestHandler( - job_table=self.job_table, - image_status_monitor=self.image_status_monitor, - endpoint_statistics_table=self.endpoint_statistics_table, - tiling_strategy=self.tiling_strategy, - region_request_queue=self.region_request_queue, - region_request_table=self.region_request_table, - endpoint_utils=self.endpoint_utils, - config=self.config, - region_request_handler=self.region_request_handler, - ) - self.running = False - - def run(self) -> None: - """ - Starts ModelRunner in a loop that continuously monitors the image work queue and region work queue. - - :return: None - """ - self.monitor_work_queues() - - def stop(self) -> None: - """ - Stops ModelRunner by setting the global run variable to False. - - :return: None - """ - self.running = False - - def monitor_work_queues(self) -> None: - """ - Monitors SQS queues for ImageRequest and RegionRequest The region work queue is checked first and will wait for - up to 10 seconds to start work. Only if no regions need to be processed in that time will this worker check to - see if a new image can be started. Ultimately this setup is intended to ensure that all the regions for an image - are completed by the cluster before work begins on more images. - - :return: None - """ - # Set the running state to True - self.running = True - - # Set up the GDAL configuration options that should remain unchanged for the life of this execution - set_gdal_default_configuration() - try: - while self.running: - logger.debug("Checking work queue for regions to process ...") - (receipt_handle, region_request_attributes) = next(self.region_requests_iter) - ThreadingLocalContextFilter.set_context(region_request_attributes) - - # If we found a region request on the queue - if region_request_attributes is not None: - try: - # Parse the message into a working RegionRequest - region_request = RegionRequest(region_request_attributes) - - # If the image request has a s3 url lets augment its path for virtual hosting - if "s3:/" in region_request.image_url: - # Validate that image exists in S3 - ImageRequest.validate_image_path(region_request.image_url, region_request.image_read_role) - image_path = region_request.image_url.replace("s3:/", "/vsis3", 1) - else: - image_path = region_request.image_url - - # Load the image into a GDAL dataset - raster_dataset, sensor_model = load_gdal_dataset(image_path) - image_format = str(raster_dataset.GetDriver().ShortName).upper() - - # Get RegionRequestItem if not create new RegionRequestItem - region_request_item = self.region_request_table.get_region_request( - region_request.region_id, region_request.image_id - ) - if region_request_item is None: - # Create a new item from the region request - region_request_item = RegionRequestItem.from_region_request(region_request) - - # Add the item to the table and start it processing - self.region_request_table.start_region_request(region_request_item) - logging.debug( - ( - f"Adding region request: image id: {region_request_item.image_id} - " - f"region id: {region_request_item.region_id}" - ) - ) - - # Process our region request - image_request_item = self.region_request_handler.process_region_request( - region_request, region_request_item, raster_dataset, sensor_model - ) - - # Check if the image is complete - if self.job_table.is_image_request_complete(image_request_item): - # If so complete the image request - self.image_request_handler.complete_image_request( - region_request, image_format, raster_dataset, sensor_model - ) - - # Update the queue - self.region_request_queue.finish_request(receipt_handle) - except RetryableJobException: - self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0) - except SelfThrottledRegionException: - self.region_request_queue.reset_request( - receipt_handle, - visibility_timeout=int(self.config.throttling_retry_timeout), - ) - except Exception as err: - logger.error(f"There was a problem processing the region request: {err}") - self.region_request_queue.finish_request(receipt_handle) - else: - logger.debug("Checking work queue for images to process ...") - (receipt_handle, image_request_message) = next(self.image_requests_iter) - - # If we found a request on the queue - if image_request_message is not None: - image_request = None - try: - # Parse the message into a working ImageRequest - image_request = ImageRequest.from_external_message(image_request_message) - ThreadingLocalContextFilter.set_context(image_request.__dict__) - - # Check that our image request looks good - if not image_request.is_valid(): - error = f"Invalid image request: {image_request_message}" - logger.exception(error) - raise InvalidImageRequestException(error) - - # Process the request - self.image_request_handler.process_image_request(image_request) - - # Update the queue - self.image_request_queue.finish_request(receipt_handle) - except RetryableJobException: - self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0) - except Exception as err: - logger.error(f"There was a problem processing the image request: {err}") - min_image_id = image_request.image_id if image_request and image_request.image_id else "" - min_job_id = image_request.job_id if image_request and image_request.job_id else "" - minimal_job_item = JobItem( - image_id=min_image_id, - job_id=min_job_id, - processing_duration=0, - ) - self.image_request_handler.fail_image_request(minimal_job_item, err) - self.image_request_queue.finish_request(receipt_handle) - finally: - # If we stop monitoring the queue set run state to false - self.running = False diff --git a/src/aws/osml/model_runner/image_request_handler.py b/src/aws/osml/model_runner/image_request_handler.py index 382fe03d..199d35e5 100644 --- a/src/aws/osml/model_runner/image_request_handler.py +++ b/src/aws/osml/model_runner/image_request_handler.py @@ -8,13 +8,14 @@ from typing import List, Optional, Tuple import shapely.geometry.base -from aws_embedded_metrics import MetricsLogger +from aws_embedded_metrics import MetricsLogger, metric_scope from aws_embedded_metrics.unit import Unit from geojson import Feature from osgeo import gdal from osgeo.gdal import Dataset from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset +from aws.osml.model_runner.api import get_image_path from aws.osml.photogrammetry import SensorModel from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, RegionRequest @@ -32,7 +33,6 @@ from .exceptions import ( AggregateFeaturesException, AggregateOutputFeaturesException, - InvalidImageURLException, LoadImageException, ProcessImageException, UnsupportedModelException, @@ -278,18 +278,8 @@ def load_image_request( # stored in S3) within this "with" statement will be made using customer credentials. At # the end of the "with" scope the credentials will be removed. with GDALConfigEnv().with_aws_credentials(assumed_credentials): - # Use GDAL to access the dataset and geo positioning metadata - if not job_item.image_url: - raise InvalidImageURLException("No image URL specified. Image URL is required.") - - # If the image request has a valid s3 image url, otherwise this is a local file - if "s3:/" in job_item.image_url: - # Validate that image exists in S3 - ImageRequest.validate_image_path(job_item.image_url, job_item.image_read_role) - - image_path = job_item.image_url.replace("s3:/", "/vsis3", 1) - else: - image_path = job_item.image_url + # Extract the virtual image path from the request + image_path = get_image_path(job_item.image_url, job_item.image_read_role) # Use gdal to load the image url we were given raster_dataset, sensor_model = load_gdal_dataset(image_path) @@ -332,12 +322,7 @@ def fail_image_request(self, job_item: JobItem, err: Exception) -> None: self.job_table.end_image_request(job_item.image_id) def complete_image_request( - self, - region_request: RegionRequest, - image_format: str, - raster_dataset: gdal.Dataset, - sensor_model: SensorModel, - metrics: MetricsLogger = None, + self, region_request: RegionRequest, image_format: str, raster_dataset: gdal.Dataset, sensor_model: SensorModel ) -> None: """ Completes the image request after all regions have been processed. Aggregates and sinks the features, @@ -347,7 +332,6 @@ def complete_image_request( :param image_format: The format of the image file. :param raster_dataset: The GDAL dataset of the processed image. :param sensor_model: The sensor model for the image, if available. - :param metrics: Optional metrics logger for performance metrics. :raises AggregateFeaturesException: If feature aggregation fails. :return: None @@ -373,28 +357,32 @@ def complete_image_request( final_features = add_properties_to_features(job_item.job_id, job_item.feature_properties, deduped_features) # Sink features to target outputs - self.sink_features(job_item, final_features, metrics) + self.sink_features(job_item, final_features) # Finalize and update the job table with the completed request - self.end_image_request(job_item, image_format, metrics) + self.end_image_request(job_item, image_format) except Exception as err: raise AggregateFeaturesException("Failed to aggregate features for region!") from err + @metric_scope def deduplicate( self, job_item: JobItem, features: List[Feature], raster_dataset: gdal.Dataset, sensor_model: SensorModel, + metrics: MetricsLogger = None, ) -> List[Feature]: """ Deduplicate the features and add additional properties to them, if applicable. + :param metrics: :param job_item: The image processing job item containing job-specific information. :param features: A list of GeoJSON features to deduplicate. :param raster_dataset: The GDAL dataset representing the image being processed. :param sensor_model: The sensor model associated with the dataset, used for georeferencing. + :param metrics: Optional metrics logger for tracking performance metrics. :return: A list of deduplicated features with additional properties added. """ @@ -402,7 +390,7 @@ def deduplicate( task_str="Select (deduplicate) image features", metric_name=MetricLabels.DURATION, logger=logger, - metrics_logger=None, + metrics_logger=metrics, ): # Calculate processing bounds based on the region of interest (ROI) and sensor model processing_bounds = self.calculate_processing_bounds(raster_dataset, sensor_model, job_item.roi_wkt) @@ -438,6 +426,7 @@ def validate_model_hosting(self, image_request: JobItem): ) raise UnsupportedModelException(error) + @metric_scope def end_image_request(self, job_item: JobItem, image_format: str, metrics: MetricsLogger = None) -> None: """ Finalizes the image request, updates the job status, and logs the necessary metrics. @@ -499,6 +488,7 @@ def calculate_processing_bounds( return processing_bounds @staticmethod + @metric_scope def sink_features(job_item: JobItem, features: List[Feature], metrics: MetricsLogger = None) -> None: """ Sink the deduplicated features to the specified output (e.g., S3, Kinesis, etc.). diff --git a/src/aws/osml/model_runner/model_runner.py b/src/aws/osml/model_runner/model_runner.py new file mode 100644 index 00000000..caaa952d --- /dev/null +++ b/src/aws/osml/model_runner/model_runner.py @@ -0,0 +1,242 @@ +# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. + +import logging + +from osgeo import gdal + +from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration +from aws.osml.model_runner.api import get_image_path + +from .api import ImageRequest, InvalidImageRequestException, RegionRequest +from .app_config import ServiceConfig +from .common import EndpointUtils, ThreadingLocalContextFilter +from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable +from .exceptions import RetryableJobException, SelfThrottledRegionException +from .image_request_handler import ImageRequestHandler +from .queue import RequestQueue +from .region_request_handler import RegionRequestHandler +from .status import ImageStatusMonitor, RegionStatusMonitor +from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy + +# Set up logging configuration +logger = logging.getLogger(__name__) +gdal.UseExceptions() + + +class ModelRunner: + """ + Main class for operating the ModelRunner application. It monitors input queues for processing requests, + decomposes the image into smaller regions and tiles, invokes an ML model on each tile, and aggregates + the results into a single output, which can be sent to the configured output sinks. + """ + + def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None: + """ + Initialize a model runner with the injectable behaviors. + + :param tiling_strategy: Defines how a larger image will be broken into chunks for processing + + :return: None + """ + self.config = ServiceConfig() + self.tiling_strategy = tiling_strategy + + # Set up queues and monitors + self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0) + self.image_requests_iter = iter(self.image_request_queue) + self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10) + self.region_requests_iter = iter(self.region_request_queue) + + # Set up tables and status monitors + self.job_table = JobTable(self.config.job_table) + self.region_request_table = RegionRequestTable(self.config.region_request_table) + self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table) + self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic) + self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic) + self.endpoint_utils = EndpointUtils() + + # Handlers for image and region processing + self.region_request_handler = RegionRequestHandler( + region_request_table=self.region_request_table, + job_table=self.job_table, + region_status_monitor=self.region_status_monitor, + endpoint_statistics_table=self.endpoint_statistics_table, + tiling_strategy=self.tiling_strategy, + endpoint_utils=self.endpoint_utils, + config=self.config, + ) + self.image_request_handler = ImageRequestHandler( + job_table=self.job_table, + image_status_monitor=self.image_status_monitor, + endpoint_statistics_table=self.endpoint_statistics_table, + tiling_strategy=self.tiling_strategy, + region_request_queue=self.region_request_queue, + region_request_table=self.region_request_table, + endpoint_utils=self.endpoint_utils, + config=self.config, + region_request_handler=self.region_request_handler, + ) + + self.running = False + + def run(self) -> None: + """ + Start the ModelRunner to continuously monitor and process work queues. + + :return: None + """ + logger.info("Starting ModelRunner") + self.running = True + self.monitor_work_queues() + + def stop(self) -> None: + """ + Stop the ModelRunner. + + :return: None + """ + logger.info("Stopping ModelRunner") + self.running = False + + def monitor_work_queues(self) -> None: + """ + Continuously monitors the SQS queues for RegionRequest and ImageRequest. + :return: None + """ + set_gdal_default_configuration() + logger.info("Beginning monitoring request queues") + while self.running: + try: + # If there are regions to process + if not self._process_region_requests(): + # Move along to the next image request if present + self._process_image_requests() + except Exception as err: + logger.error(f"Unexpected error in monitor_work_queues: {err}") + self.running = False + logger.info("Stopped monitoring request queues") + + def _process_region_requests(self) -> bool: + """ + Process messages from the region request queue. + + :return: True if a region request was processed, False if not. + """ + logger.debug("Checking work queue for regions to process...") + try: + receipt_handle, region_request_attributes = next(self.region_requests_iter) + except StopIteration: + # No region requests available in the queue + logger.debug("No region requests available to process.") + return False + + if region_request_attributes: + ThreadingLocalContextFilter.set_context(region_request_attributes) + try: + region_request = RegionRequest(region_request_attributes) + image_path = get_image_path(region_request.image_url, region_request.image_read_role) + raster_dataset, sensor_model = load_gdal_dataset(image_path) + region_request_item = self._get_or_create_region_request_item(region_request) + image_request_item = self.region_request_handler.process_region_request( + region_request, region_request_item, raster_dataset, sensor_model + ) + if self.job_table.is_image_request_complete(image_request_item): + self.image_request_handler.complete_image_request( + region_request, str(raster_dataset.GetDriver().ShortName).upper(), raster_dataset, sensor_model + ) + self.region_request_queue.finish_request(receipt_handle) + except RetryableJobException as err: + logger.warning(f"Retrying region request due to: {err}") + self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0) + except SelfThrottledRegionException as err: + logger.warning(f"Retrying region request due to: {err}") + self.region_request_queue.reset_request( + receipt_handle, visibility_timeout=int(self.config.throttling_retry_timeout) + ) + except Exception as err: + logger.exception(f"Error processing region request: {err}") + self.region_request_queue.finish_request(receipt_handle) + finally: + return True + else: + return False + + def _process_image_requests(self) -> bool: + """ + Processes messages from the image request queue. + + This method retrieves and processes image requests from the SQS queue. It validates + the image request, and if valid, passes it to the `ImageRequestHandler` for further + processing. In case of a retryable exception, the request is reset in the queue with + a visibility timeout. If the image request fails due to an error, it is marked as + failed and the appropriate actions are taken. + + :raises InvalidImageRequestException: If the image request is found to be invalid. + :raises Exception: If an unexpected error occurs during processing. + + :return: True if a image request was processed, False if not. + """ + logger.debug("Checking work queue for images to process...") + receipt_handle, image_request_message = next(self.image_requests_iter) + image_request = None + if image_request_message: + try: + image_request = ImageRequest.from_external_message(image_request_message) + ThreadingLocalContextFilter.set_context(image_request.__dict__) + + if not image_request.is_valid(): + raise InvalidImageRequestException(f"Invalid image request: {image_request_message}") + + self.image_request_handler.process_image_request(image_request) + self.image_request_queue.finish_request(receipt_handle) + except RetryableJobException: + self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0) + except Exception as err: + logger.error(f"Error processing image request: {err}") + if image_request: + self._fail_image_request(image_request, err) + self.image_request_queue.finish_request(receipt_handle) + finally: + return True + else: + return False + + def _fail_image_request(self, image_request: ImageRequest, error: Exception) -> None: + """ + Handles failing an image request by updating its status and logging the failure. + + This method is called when an image request cannot be processed due to an error. + It marks the image request as failed and updates the job status using the + `ImageRequestHandler`. + + :param image_request: The image request that failed to process. + :param error: The exception that caused the failure. + + :return: None + """ + min_image_id = image_request.image_id if image_request else "" + min_job_id = image_request.job_id if image_request else "" + minimal_job_item = JobItem(image_id=min_image_id, job_id=min_job_id, processing_duration=0) + self.image_request_handler.fail_image_request(minimal_job_item, error) + + def _get_or_create_region_request_item(self, region_request: RegionRequest) -> RegionRequestItem: + """ + Retrieves or creates a `RegionRequestItem` in the region request table. + + This method checks if a region request already exists in the `RegionRequestTable`. + If it does, it retrieves the existing request; otherwise, it creates a new + `RegionRequestItem` from the provided `RegionRequest` and starts the region + processing. + + :param region_request: The region request to process. + + :return: The retrieved or newly created `RegionRequestItem`. + """ + region_request_item = self.region_request_table.get_region_request(region_request.region_id, region_request.image_id) + if region_request_item is None: + region_request_item = RegionRequestItem.from_region_request(region_request) + self.region_request_table.start_region_request(region_request_item) + logger.debug( + f"Added region request: image id {region_request_item.image_id}, region id {region_request_item.region_id}" + ) + return region_request_item diff --git a/test/aws/osml/model_runner/test_image_request_handler.py b/test/aws/osml/model_runner/test_image_request_handler.py index 41b7811c..ba685ebf 100644 --- a/test/aws/osml/model_runner/test_image_request_handler.py +++ b/test/aws/osml/model_runner/test_image_request_handler.py @@ -140,6 +140,7 @@ def test_complete_image_request(self, mock_aggregate_features, mock_deduplicate, mock_deduplicate.return_value = mock_features mock_aggregate_features.return_value = mock_features mock_sink_features.return_value = True + self.mock_job_item.processing_duration = 1000 # Call complete_image_request self.handler.complete_image_request(mock_region_request, "tif", mock_raster_dataset, mock_sensor_model) diff --git a/test/aws/osml/model_runner/test_model_runner.py b/test/aws/osml/model_runner/test_model_runner.py new file mode 100644 index 00000000..45705369 --- /dev/null +++ b/test/aws/osml/model_runner/test_model_runner.py @@ -0,0 +1,128 @@ +# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. + +import unittest +from unittest.mock import MagicMock, patch + +from aws.osml.model_runner.model_runner import ModelRunner, RetryableJobException + + +class TestModelRunner(unittest.TestCase): + def setUp(self): + # Create the instance of ModelRunner + self.runner = ModelRunner() + # Mock the process methods + self.runner.image_request_handler = MagicMock() + self.runner.region_request_handler = MagicMock() + + def test_run_starts_monitoring(self): + """Test that the `run` method sets up and starts the monitoring loop.""" + # Mock method calls + self.runner.monitor_work_queues = MagicMock() + + # Call the method + self.runner.run() + + # Ensure the run method calls monitor_work_queues and sets `self.running` + self.assertTrue(self.runner.running) + self.runner.monitor_work_queues.assert_called_once() + + def test_stop_stops_running(self): + """Test that the `stop` method correctly stops the runner.""" + # Call stop + self.runner.stop() + + # Check if `self.running` is set to False + self.assertFalse(self.runner.running) + + @patch("aws.osml.model_runner.model_runner.RegionRequestHandler.process_region_request") + @patch("aws.osml.model_runner.model_runner.RequestQueue.finish_request") + @patch("aws.osml.model_runner.model_runner.load_gdal_dataset") + def test_process_region_requests_success(self, mock_load_gdal, mock_finish_request, mock_process_region): + """Test processing of region requests successfully.""" + mock_region_request_item = MagicMock() + mock_image_request_item = MagicMock() + self.runner._get_or_create_region_request_item = MagicMock(return_value=mock_region_request_item) + mock_load_gdal.return_value = (MagicMock(), MagicMock()) + mock_process_region.return_value = mock_image_request_item + self.runner.job_table.is_image_request_complete = MagicMock(return_value=True) + + # Simulate queue data + self.runner.region_requests_iter = iter([("receipt_handle", {"region_id": "region_123"})]) + + # Call method + self.runner._process_region_requests() + + # Ensure region request was processed correctly + self.runner.image_request_handler.complete_image_request.assert_called_once() + mock_finish_request.assert_called_once_with("receipt_handle") + + @patch("aws.osml.model_runner.model_runner.ImageRequest") + @patch("aws.osml.model_runner.model_runner.RequestQueue.finish_request") + def test_process_image_requests_invalid(self, mock_finish_request, mock_image_request): + """Test that invalid image requests raise an InvalidImageRequestException.""" + # Mock invalid image request + mock_image_request_message = MagicMock() + mock_image_request_instance = MagicMock(is_valid=MagicMock(return_value=False)) + mock_image_request.from_external_message.return_value = mock_image_request_instance + self.runner.image_requests_iter = iter([("receipt_handle", mock_image_request_message)]) + + self.runner._process_image_requests() + + # Ensure request was marked as completed + mock_finish_request.assert_called_once_with("receipt_handle") + + @patch("aws.osml.model_runner.model_runner.ImageRequest") + @patch("aws.osml.model_runner.model_runner.RequestQueue.reset_request") + def test_process_image_requests_retryable(self, mock_reset_request, mock_image_request): + """Test that a RetryableJobException resets the request.""" + # Mock retryable job exception + mock_image_request_message = MagicMock() + mock_image_request_instance = MagicMock(is_valid=MagicMock(return_value=True)) + mock_image_request.from_external_message.return_value = mock_image_request_instance + self.runner.image_requests_iter = iter([("receipt_handle", mock_image_request_message)]) + self.runner.image_request_handler.process_image_request.side_effect = RetryableJobException() + + # Call method + self.runner._process_image_requests() + + # Ensure request was reset + mock_reset_request.assert_called_once_with("receipt_handle", visibility_timeout=0) + + @patch("aws.osml.model_runner.model_runner.ImageRequest") + def test_process_image_requests_general_error(self, mock_image_request): + """Test that general exceptions mark the image request as failed.""" + # Mock exception + mock_image_request_message = MagicMock() + mock_image_request_instance = MagicMock() + self.runner._fail_image_request = MagicMock() + mock_image_request.from_external_message.return_value = mock_image_request_instance + self.runner.image_requests_iter = iter([("receipt_handle", mock_image_request_message)]) + self.runner.image_request_handler.process_image_request.side_effect = Exception("Some error") + + # Call method + self.runner._process_image_requests() + + # Ensure image request was failed + self.runner._fail_image_request.assert_called() + + @patch("aws.osml.model_runner.model_runner.RegionRequestHandler.process_region_request") + @patch("aws.osml.model_runner.model_runner.RequestQueue.finish_request") + @patch("aws.osml.model_runner.model_runner.RequestQueue.reset_request") + def test_process_region_requests_general_error(self, mock_reset_request, mock_finish_request, mock_process_region): + """Test that general exceptions log an error and complete the request.""" + # Mock exception + mock_process_region.side_effect = Exception("Some region processing error") + + # Simulate queue data + self.runner.region_requests_iter = iter([("receipt_handle", {"region_id": "region_123"})]) + + # Call method + self.runner._process_region_requests() + + # Ensure the request was completed and logged + mock_finish_request.assert_called_once_with("receipt_handle") + mock_reset_request.assert_not_called() # Ensure no reset on general errors + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_api.py b/test/test_api.py index 6a28303a..3bda0b25 100755 --- a/test/test_api.py +++ b/test/test_api.py @@ -257,7 +257,7 @@ def test_image_request_invalid_sink(self): def test_image_request_invalid_image_path(self): from aws.osml.model_runner.api.exceptions import InvalidS3ObjectException - from aws.osml.model_runner.api.image_request import ImageRequest + from aws.osml.model_runner.api.request_utils import validate_image_path from aws.osml.model_runner.app_config import BotoConfig s3_client = boto3.client("s3", config=BotoConfig.default) @@ -272,7 +272,7 @@ def test_image_request_invalid_image_path(self): ) with pytest.raises(InvalidS3ObjectException): - ImageRequest.validate_image_path(TEST_S3_FULL_BUCKET_PATH, None) + validate_image_path(TEST_S3_FULL_BUCKET_PATH, None) if __name__ == "__main__": diff --git a/test/test_app.py b/test/test_end_to_end.py similarity index 99% rename from test/test_app.py rename to test/test_end_to_end.py index 3f5ddc1a..ae2faeeb 100755 --- a/test/test_app.py +++ b/test/test_end_to_end.py @@ -14,7 +14,7 @@ @mock_aws -class TestModelRunner(TestCase): +class TestModelRunnerEndToEnd(TestCase): """ Unit tests for the ModelRunner application. @@ -31,12 +31,12 @@ def setUp(self) -> None: """ from aws.osml.model_runner.api import RegionRequest from aws.osml.model_runner.api.image_request import ImageRequest - from aws.osml.model_runner.app import ModelRunner from aws.osml.model_runner.app_config import BotoConfig from aws.osml.model_runner.database.endpoint_statistics_table import EndpointStatisticsTable from aws.osml.model_runner.database.feature_table import FeatureTable from aws.osml.model_runner.database.job_table import JobTable from aws.osml.model_runner.database.region_request_table import RegionRequestTable + from aws.osml.model_runner.model_runner import ModelRunner from aws.osml.model_runner.status import ImageStatusMonitor, RegionStatusMonitor # Required to avoid warnings from GDAL