diff --git a/bin/oversightml-mr-entry-point.py b/bin/oversightml-mr-entry-point.py
index 3f924c7f..ad58fff9 100644
--- a/bin/oversightml-mr-entry-point.py
+++ b/bin/oversightml-mr-entry-point.py
@@ -11,8 +11,8 @@
from codeguru_profiler_agent import Profiler
from pythonjsonlogger import jsonlogger
-from aws.osml.model_runner.app import ModelRunner
from aws.osml.model_runner.common import ThreadingLocalContextFilter
+from aws.osml.model_runner.model_runner import ModelRunner
def handler_stop_signals(signal_num: int, frame: Optional[FrameType], model_runner: ModelRunner) -> None:
diff --git a/src/aws/osml/model_runner/api/image_request.py b/src/aws/osml/model_runner/api/image_request.py
index e1a7a2d2..6a1d4b87 100755
--- a/src/aws/osml/model_runner/api/image_request.py
+++ b/src/aws/osml/model_runner/api/image_request.py
@@ -9,7 +9,6 @@
import shapely.wkt
from shapely.geometry.base import BaseGeometry
-from aws.osml.model_runner.app_config import BotoConfig
from aws.osml.model_runner.common import (
FeatureDistillationAlgorithm,
FeatureDistillationNMS,
@@ -21,6 +20,7 @@
deserialize_post_processing_list,
get_credentials_for_assumed_role,
)
+from aws.osml.model_runner.config import BotoConfig
from .exceptions import InvalidS3ObjectException
from .inference import ModelInvokeMode
diff --git a/src/aws/osml/model_runner/common/credentials_utils.py b/src/aws/osml/model_runner/common/credentials_utils.py
index f7e09c07..95eb0050 100755
--- a/src/aws/osml/model_runner/common/credentials_utils.py
+++ b/src/aws/osml/model_runner/common/credentials_utils.py
@@ -4,7 +4,7 @@
import boto3
-from aws.osml.model_runner.app_config import BotoConfig
+from aws.osml.model_runner.config import BotoConfig
from .exceptions import InvalidAssumedRoleException
diff --git a/src/aws/osml/model_runner/common/endpoint_utils.py b/src/aws/osml/model_runner/common/endpoint_utils.py
index 40cac830..b02a159e 100755
--- a/src/aws/osml/model_runner/common/endpoint_utils.py
+++ b/src/aws/osml/model_runner/common/endpoint_utils.py
@@ -8,7 +8,7 @@
import boto3
from cachetools import TTLCache, cachedmethod
-from aws.osml.model_runner.app_config import BotoConfig, ServiceConfig
+from aws.osml.model_runner.config import BotoConfig, ServiceConfig
from .credentials_utils import get_credentials_for_assumed_role
diff --git a/src/aws/osml/model_runner/app_config.py b/src/aws/osml/model_runner/config.py
similarity index 100%
rename from src/aws/osml/model_runner/app_config.py
rename to src/aws/osml/model_runner/config.py
diff --git a/src/aws/osml/model_runner/database/ddb_helper.py b/src/aws/osml/model_runner/database/ddb_helper.py
index 34835260..670334ff 100755
--- a/src/aws/osml/model_runner/database/ddb_helper.py
+++ b/src/aws/osml/model_runner/database/ddb_helper.py
@@ -10,7 +10,7 @@
import boto3
from boto3.dynamodb.conditions import Key
-from aws.osml.model_runner.app_config import BotoConfig
+from aws.osml.model_runner.config import BotoConfig
from .exceptions import DDBBatchWriteException, DDBUpdateException
diff --git a/src/aws/osml/model_runner/database/feature_table.py b/src/aws/osml/model_runner/database/feature_table.py
index 0c88dea3..1638592b 100755
--- a/src/aws/osml/model_runner/database/feature_table.py
+++ b/src/aws/osml/model_runner/database/feature_table.py
@@ -16,8 +16,8 @@
from dacite import from_dict
from geojson import Feature
-from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig
from aws.osml.model_runner.common import ImageDimensions, Timer, get_feature_image_bounds
+from aws.osml.model_runner.config import MetricLabels, ServiceConfig
from .ddb_helper import DDBHelper, DDBItem, DDBKey
from .exceptions import AddFeaturesException
diff --git a/src/aws/osml/model_runner/app.py b/src/aws/osml/model_runner/image_request_handler.py
old mode 100755
new mode 100644
similarity index 55%
rename from src/aws/osml/model_runner/app.py
rename to src/aws/osml/model_runner/image_request_handler.py
index a7da9f6d..d14172d6
--- a/src/aws/osml/model_runner/app.py
+++ b/src/aws/osml/model_runner/image_request_handler.py
@@ -11,18 +11,16 @@
from typing import Any, Dict, List, Optional, Tuple
import shapely.geometry.base
-from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
-from aws_embedded_metrics.metric_scope import metric_scope
+from aws_embedded_metrics import MetricsLogger, metric_scope
from aws_embedded_metrics.unit import Unit
from geojson import Feature
from osgeo import gdal
from osgeo.gdal import Dataset
-from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset, set_gdal_default_configuration
+from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset
from aws.osml.photogrammetry import ImageCoordinate, SensorModel
from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode
-from .app_config import MetricLabels, ServiceConfig
from .common import (
EndpointUtils,
FeatureDistillationDeserializer,
@@ -30,11 +28,11 @@
ImageDimensions,
ImageRegion,
RequestStatus,
- ThreadingLocalContextFilter,
Timer,
get_credentials_for_assumed_role,
mr_post_processing_options_factory,
)
+from .config import MetricLabels, ServiceConfig
from .database import EndpointStatisticsTable, FeatureTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable
from .exceptions import (
AggregateFeaturesException,
@@ -43,16 +41,14 @@
InvalidImageURLException,
LoadImageException,
ProcessImageException,
- ProcessRegionException,
- RetryableJobException,
- SelfThrottledRegionException,
UnsupportedModelException,
)
from .inference import FeatureSelector, calculate_processing_bounds, get_source_property
from .queue import RequestQueue
+from .region_request_handler import RegionRequestHandler
from .sink import SinkFactory
-from .status import ImageStatusMonitor, RegionStatusMonitor
-from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy, process_tiles, setup_tile_workers
+from .status import ImageStatusMonitor
+from .tile_worker import TilingStrategy
# Set up logging configuration
logger = logging.getLogger(__name__)
@@ -62,177 +58,50 @@
gdal.UseExceptions()
-class ModelRunner:
+class ImageRequestHandler:
"""
- Main class for operating the ModelRunner application. It monitors input queues for processing requests,
- decomposes the image into a set of smaller regions and tiles, invokes an ML model endpoint with each tile, and
- finally aggregates all the results into a single output which can be deposited into the desired output sinks.
+ Class responsible for handling ImageRequest processing.
"""
- def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None:
- """
- Initialize a model runner with the injectable behaviors.
-
- :param tiling_strategy: class defining how a larger image will be broken into chunks for processing
- """
- self.config = ServiceConfig()
- self.tiling_strategy = tiling_strategy
- self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0)
- self.image_requests_iter = iter(self.image_request_queue)
- self.job_table = JobTable(self.config.job_table)
- self.region_request_table = RegionRequestTable(self.config.region_request_table)
- self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table)
- self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10)
- self.region_requests_iter = iter(self.region_request_queue)
- self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic)
- self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic)
- self.endpoint_utils = EndpointUtils()
- self.running = False
-
- def run(self) -> None:
- """
- Starts ModelRunner in a loop that continuously monitors the image work queue and region work queue.
-
- :return: None
- """
- self.monitor_work_queues()
-
- def stop(self) -> None:
- """
- Stops ModelRunner by setting the global run variable to False.
-
- :return: None
+ def __init__(
+ self,
+ job_table: JobTable,
+ image_status_monitor: ImageStatusMonitor,
+ endpoint_statistics_table: EndpointStatisticsTable,
+ tiling_strategy: TilingStrategy,
+ region_request_queue: RequestQueue,
+ region_request_table: RegionRequestTable,
+ endpoint_utils: EndpointUtils,
+ config: ServiceConfig,
+ region_request_handler: RegionRequestHandler,
+ ) -> None:
"""
- self.running = False
+ Initialize the ImageRequestHandler with the necessary dependencies.
- def monitor_work_queues(self) -> None:
+ :param job_table: The job table for image processing.
+ :param image_status_monitor: A monitor to track image request status.
+ :param endpoint_statistics_table: Table for tracking endpoint statistics.
+ :param tiling_strategy: The strategy for handling image tiling.
+ :param region_request_queue: Queue to send region requests.
+ :param region_request_table: Table to track region request progress.
+ :param endpoint_utils: Utility class for handling endpoint-related operations.
+ :param config: Configuration settings for the service.
+ :param region_request_handler: Handler used to process the first region of an image directly in this worker.
"""
- Monitors SQS queues for ImageRequest and RegionRequest The region work queue is checked first and will wait for
- up to 10 seconds to start work. Only if no regions need to be processed in that time will this worker check to
- see if a new image can be started. Ultimately this setup is intended to ensure that all the regions for an image
- are completed by the cluster before work begins on more images.
- :return: None
- """
- # Set the running state to True
- self.running = True
-
- # Set up the GDAL configuration options that should remain unchanged for the life of this execution
- set_gdal_default_configuration()
- try:
- while self.running:
- logger.debug("Checking work queue for regions to process ...")
- (receipt_handle, region_request_attributes) = next(self.region_requests_iter)
- ThreadingLocalContextFilter.set_context(region_request_attributes)
-
- # If we found a region request on the queue
- if region_request_attributes is not None:
- try:
- # Parse the message into a working RegionRequest
- region_request = RegionRequest(region_request_attributes)
-
- # If the image request has a s3 url lets augment its path for virtual hosting
- if "s3:/" in region_request.image_url:
- # Validate that image exists in S3
- ImageRequest.validate_image_path(region_request.image_url, region_request.image_read_role)
- image_path = region_request.image_url.replace("s3:/", "/vsis3", 1)
- else:
- image_path = region_request.image_url
-
- # Load the image into a GDAL dataset
- raster_dataset, sensor_model = load_gdal_dataset(image_path)
- image_format = str(raster_dataset.GetDriver().ShortName).upper()
-
- # Get RegionRequestItem if not create new RegionRequestItem
- region_request_item = self.region_request_table.get_region_request(
- region_request.region_id, region_request.image_id
- )
- if region_request_item is None:
- # Create a new item from the region request
- region_request_item = RegionRequestItem.from_region_request(region_request)
-
- # Add the item to the table and start it processing
- self.region_request_table.start_region_request(region_request_item)
- logging.debug(
- (
- f"Adding region request: image id: {region_request_item.image_id} - "
- f"region id: {region_request_item.region_id}"
- )
- )
-
- # Process our region request
- image_request_item = self.process_region_request(
- region_request, region_request_item, raster_dataset, sensor_model
- )
-
- # Check if the image is complete
- if self.job_table.is_image_request_complete(image_request_item):
- # If so complete the image request
- self.complete_image_request(region_request, image_format, raster_dataset, sensor_model)
-
- # Update the queue
- self.region_request_queue.finish_request(receipt_handle)
- except RetryableJobException:
- self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0)
- except SelfThrottledRegionException:
- self.region_request_queue.reset_request(
- receipt_handle,
- visibility_timeout=int(self.config.throttling_retry_timeout),
- )
- except Exception as err:
- logger.error(f"There was a problem processing the region request: {err}")
- self.region_request_queue.finish_request(receipt_handle)
- else:
- logger.debug("Checking work queue for images to process ...")
- (receipt_handle, image_request_message) = next(self.image_requests_iter)
-
- # If we found a request on the queue
- if image_request_message is not None:
- image_request = None
- try:
- # Parse the message into a working ImageRequest
- image_request = ImageRequest.from_external_message(image_request_message)
- ThreadingLocalContextFilter.set_context(image_request.__dict__)
-
- # Check that our image request looks good
- if not image_request.is_valid():
- error = f"Invalid image request: {image_request_message}"
- logger.exception(error)
- raise InvalidImageRequestException(error)
-
- # Process the request
- self.process_image_request(image_request)
-
- # Update the queue
- self.image_request_queue.finish_request(receipt_handle)
- except RetryableJobException:
- self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0)
- except Exception as err:
- logger.error(f"There was a problem processing the image request: {err}")
- min_image_id = image_request.image_id if image_request and image_request.image_id else ""
- min_job_id = image_request.job_id if image_request and image_request.job_id else ""
- minimal_job_item = JobItem(
- image_id=min_image_id,
- job_id=min_job_id,
- processing_duration=0,
- )
- self.fail_image_request_send_messages(minimal_job_item, err)
- self.image_request_queue.finish_request(receipt_handle)
- finally:
- # If we stop monitoring the queue set run state to false
- self.running = False
+ self.job_table = job_table
+ self.image_status_monitor = image_status_monitor
+ self.endpoint_statistics_table = endpoint_statistics_table
+ self.tiling_strategy = tiling_strategy
+ self.region_request_queue = region_request_queue
+ self.region_request_table = region_request_table
+ self.endpoint_utils = endpoint_utils
+ self.config = config
+ self.region_request_handler = region_request_handler
def process_image_request(self, image_request: ImageRequest) -> None:
"""
- Processes ImageRequest objects that are picked up from queue. Loads the specified image into memory to be
- chipped apart into regions and sent downstream for processing via RegionRequest. This will also process the
- first region chipped from the image. # This worker will process the first region of this image since it has
- already loaded the dataset from S3 and is ready to go. Any additional regions will be queued for processing by
- other workers in this cluster.
-
- :param image_request: ImageRequest = the image request derived from the ImageRequest SQS message
-
- :return: None
+ Processes ImageRequest objects picked up from the queue.
"""
image_request_item = None
try:
@@ -243,7 +111,6 @@ def process_image_request(self, image_request: ImageRequest) -> None:
# Add entry to the endpoint statistics table
self.endpoint_statistics_table.upsert_endpoint(image_request.model_name, max_regions)
- # Update the image status to started and include relevant image meta-data
logger.debug(f"Starting processing of {image_request.image_url}")
image_request_item = JobItem(
image_id=image_request.image_id,
@@ -264,14 +131,11 @@ def process_image_request(self, image_request: ImageRequest) -> None:
asdict(feature_distillation_option_list[0], dict_factory=mr_post_processing_options_factory)
)
- # Start the image processing
self.job_table.start_image_request(image_request_item)
self.image_status_monitor.process_event(image_request_item, RequestStatus.STARTED, "Started image request")
- # Check we have a valid image request, throws if not
self.validate_model_hosting(image_request_item)
- # Load the relevant image meta data into memory
image_extension, raster_dataset, sensor_model, all_regions = self.load_image_request(
image_request_item, image_request.roi
)
@@ -281,38 +145,28 @@ def process_image_request(self, image_request: ImageRequest) -> None:
f"Dataset {image_request_item.image_id} has no geo transform. Results are not geo-referenced."
)
- # If we got valid outputs
if raster_dataset and all_regions and image_extension:
image_request_item.region_count = len(all_regions)
image_request_item.width = int(raster_dataset.RasterXSize)
image_request_item.height = int(raster_dataset.RasterYSize)
try:
- image_request_item.extents = json.dumps(ModelRunner.get_extents(raster_dataset, sensor_model))
+ image_request_item.extents = json.dumps(self.get_extents(raster_dataset, sensor_model))
except Exception as e:
logger.warning(f"Could not get extents for image: {image_request_item.image_id}")
logger.exception(e)
feature_properties: List[dict] = json.loads(image_request_item.feature_properties)
-
- # If we can get a valid source metadata from the source image - attach it to features
- # else, just pass in whatever custom features if they were provided
source_metadata = get_source_property(image_request_item.image_url, image_extension, raster_dataset)
if isinstance(source_metadata, dict):
feature_properties.append(source_metadata)
- # Update the feature properties
image_request_item.feature_properties = json.dumps(feature_properties)
-
- # Update the image request job to have new derived image data
self.job_table.update_image_request(image_request_item)
-
self.image_status_monitor.process_event(image_request_item, RequestStatus.IN_PROGRESS, "Processing regions")
- # Place the resulting region requests on the appropriate work queue
self.queue_region_request(all_regions, image_request, raster_dataset, sensor_model, image_extension)
except Exception as err:
- # We failed try and gracefully update our image request
if image_request_item:
self.fail_image_request(image_request_item, err)
else:
@@ -322,186 +176,49 @@ def process_image_request(self, image_request: ImageRequest) -> None:
processing_duration=0,
)
self.fail_image_request(minimal_job_item, err)
-
- # Let the application know that we failed to process image
raise ProcessImageException("Failed to process image region!") from err
- def queue_region_request(
- self,
- all_regions: List[ImageRegion],
- image_request: ImageRequest,
- raster_dataset: Dataset,
- sensor_model: Optional[SensorModel],
- image_extension: Optional[str],
+ def complete_image_request(
+ self, region_request: RegionRequest, image_format: str, raster_dataset: gdal.Dataset, sensor_model: SensorModel
) -> None:
"""
- Loads the list of regions into the queue. First it will create a RequestRequestItem and creates
- an entry into the RegionRequestTable for traceability. Then process the region request. Once it's completed,
- it will update an entry in the RegionRequestTable.
-
- :param image_extension: = the GDAL derived image extension
- :param all_regions: List[ImageRegion] = the list of image regions
- :param image_request: ImageRequest = the image request
- :param raster_dataset: Dataset = the raster dataset containing the region
- :param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset
-
- :return: None
- """
- # Set aside the first region
- first_region = all_regions.pop(0)
- for region in all_regions:
- logger.debug(f"Queueing region: {region}")
-
- region_request = RegionRequest(
- image_request.get_shared_values(),
- region_bounds=region,
- region_id=f"{region[0]}{region[1]}-{image_request.job_id}",
- image_extension=image_extension,
- )
-
- # Create a new entry to the region request being started
- region_request_item = RegionRequestItem.from_region_request(region_request)
- self.region_request_table.start_region_request(region_request_item)
- logging.debug(
- (
- f"Adding region request: image id: {region_request_item.image_id} - "
- f"region id: {region_request_item.region_id}"
- )
- )
-
- # Send the attributes of the region request as the message.
- self.region_request_queue.send_request(region_request.__dict__)
-
- # Go ahead and process the first region
- logger.debug(f"Processing first region {0}: {first_region}")
-
- first_region_request = RegionRequest(
- image_request.get_shared_values(),
- region_bounds=first_region,
- region_id=f"{first_region[0]}{first_region[1]}-{image_request.job_id}",
- image_extension=image_extension,
- )
-
- # Add item to RegionRequestTable
- first_region_request_item = RegionRequestItem.from_region_request(first_region_request)
- self.region_request_table.start_region_request(first_region_request_item)
- logging.debug(f"Adding region_id: {first_region_request_item.region_id}")
-
- # Processes our region request and return the updated item
- image_request_item = self.process_region_request(
- first_region_request, first_region_request_item, raster_dataset, sensor_model
- )
-
- # If the image is finished then complete it
- if self.job_table.is_image_request_complete(image_request_item):
- image_format = str(raster_dataset.GetDriver().ShortName).upper()
- self.complete_image_request(first_region_request, image_format, raster_dataset, sensor_model)
-
- @metric_scope
- def process_region_request(
- self,
- region_request: RegionRequest,
- region_request_item: RegionRequestItem,
- raster_dataset: gdal.Dataset,
- sensor_model: Optional[SensorModel] = None,
- metrics: MetricsLogger = None,
- ) -> JobItem:
- """
- Processes RegionRequest objects that are delegated for processing. Loads the specified region of an image into
- memory to be processed by tile-workers. If a raster_dataset is not provided directly it will poll the image
- from the region request.
-
- :param region_request: RegionRequest = the region request
- :param region_request_item: RegionRequestItem = the region request to update
- :param raster_dataset: gdal.Dataset = the raster dataset containing the region
- :param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset
- :param metrics: MetricsLogger = the metrics logger to use to report metrics.
-
- :return: None
+ Runs after every region has completed processing to check if that was the last region and run required
+ completion logic for the associated ImageRequest.
"""
- if isinstance(metrics, MetricsLogger):
- metrics.set_dimensions()
-
- if not region_request.is_valid():
- logger.error(f"Invalid Region Request! {region_request.__dict__}")
- raise ValueError("Invalid Region Request")
-
- if isinstance(metrics, MetricsLogger):
- image_format = str(raster_dataset.GetDriver().ShortName).upper()
- metrics.put_dimensions(
- {
- MetricLabels.OPERATION_DIMENSION: MetricLabels.REGION_PROCESSING_OPERATION,
- MetricLabels.MODEL_NAME_DIMENSION: region_request.model_name,
- MetricLabels.INPUT_FORMAT_DIMENSION: image_format,
- }
- )
-
- if self.config.self_throttling:
- max_regions = self.endpoint_utils.calculate_max_regions(
- region_request.model_name, region_request.model_invocation_role
- )
- # Add entry to the endpoint statistics table
- self.endpoint_statistics_table.upsert_endpoint(region_request.model_name, max_regions)
- in_progress = self.endpoint_statistics_table.current_in_progress_regions(region_request.model_name)
-
- if in_progress >= max_regions:
- if isinstance(metrics, MetricsLogger):
- metrics.put_metric(MetricLabels.THROTTLES, 1, str(Unit.COUNT.value))
- logger.warning(f"Throttling region request. (Max: {max_regions} In-progress: {in_progress}")
- raise SelfThrottledRegionException
-
- # Increment the endpoint region counter
- self.endpoint_statistics_table.increment_region_count(region_request.model_name)
-
try:
- with Timer(
- task_str=f"Processing region {region_request.image_url} {region_request.region_bounds}",
- metric_name=MetricLabels.DURATION,
- logger=logger,
- metrics_logger=metrics,
- ):
- # Set up our threaded tile worker pool
- tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.config.elevation_model)
-
- # Process all our tiles
- total_tile_count, failed_tile_count = process_tiles(
- self.tiling_strategy, region_request_item, tile_queue, tile_workers, raster_dataset, sensor_model
- )
+ image_request_item = self.job_table.get_image_request(region_request.image_id)
- # Update table w/ total tile counts
- region_request_item.total_tiles = total_tile_count
- region_request_item.succeeded_tile_count = total_tile_count - failed_tile_count
- region_request_item.failed_tile_count = failed_tile_count
- region_request_item = self.region_request_table.update_region_request(region_request_item)
+ logger.debug("Last region of image request was completed, aggregating features for image!")
- # Update the image request to complete this region
- image_request_item = self.job_table.complete_region_request(region_request.image_id, bool(failed_tile_count))
+ roi = None
+ if image_request_item.roi_wkt:
+ logger.debug(f"Using ROI from request to set processing boundary: {image_request_item.roi_wkt}")
+ roi = shapely.wkt.loads(image_request_item.roi_wkt)
+ processing_bounds = calculate_processing_bounds(raster_dataset, roi, sensor_model)
+ logger.debug(f"Processing boundary from {roi} is {processing_bounds}")
- # Update region request table if that region succeeded
- region_status = self.region_status_monitor.get_status(region_request_item)
- region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)
+ feature_table = FeatureTable(self.config.feature_table, region_request.tile_size, region_request.tile_overlap)
+ features = feature_table.aggregate_features(image_request_item)
+ features = self.select_features(image_request_item, features, processing_bounds)
+ features = self.add_properties_to_features(image_request_item, features)
- self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")
+ is_write_succeeded = self.sink_features(image_request_item, features)
+ if not is_write_succeeded:
+ raise AggregateOutputFeaturesException("Failed to write features to S3 or Kinesis! Please check the log...")
- # Write CloudWatch Metrics to the Logs
- if isinstance(metrics, MetricsLogger):
- # TODO: Consider adding the +1 invocation to timer
- metrics.put_metric(MetricLabels.INVOCATIONS, 1, str(Unit.COUNT.value))
+ completed_image_request_item = self.job_table.end_image_request(image_request_item.image_id)
- # Return the updated item
- return image_request_item
+ if completed_image_request_item.processing_duration is not None:
+ image_request_status = self.image_status_monitor.get_status(completed_image_request_item)
+ self.image_status_monitor.process_event(
+ completed_image_request_item, image_request_status, "Completed image processing"
+ )
+ self.generate_image_processing_metrics(completed_image_request_item, image_format)
+ else:
+ raise InvalidImageRequestException("ImageRequest has no start time")
except Exception as err:
- failed_msg = f"Failed to process image region: {err}"
- logger.error(failed_msg)
- # update the table to take in that exception
- region_request_item.message = failed_msg
- return self.fail_region_request(region_request_item)
-
- finally:
- # Decrement the endpoint region counter
- if self.config.self_throttling:
- self.endpoint_statistics_table.decrement_region_count(region_request.model_name)
+ raise AggregateFeaturesException("Failed to aggregate features for region!") from err
def load_image_request(
self,
@@ -570,85 +287,67 @@ def load_image_request(
def fail_image_request(self, image_request_item: JobItem, err: Exception) -> None:
"""
- Handles failure events/exceptions for image requests and tries to update the status monitor accordingly
+ Handles failure events/exceptions for image requests and tries to update the status monitor accordingly.
:param image_request_item: JobItem = the image request that failed.
:param err: Exception = the exception that caused the failure
-
:return: None
"""
- self.fail_image_request_send_messages(image_request_item, err)
+ # Log and publish the FAILED status here; calling self.fail_image_request() again would recurse without end.
+ logger.exception(f"Failed to start image processing!: {err}")
+ self.image_status_monitor.process_event(image_request_item, RequestStatus.FAILED, str(err))
self.job_table.end_image_request(image_request_item.image_id)
- def fail_image_request_send_messages(self, image_request_item: JobItem, err: Exception) -> None:
- """
- Updates failed metrics and update the status monitor accordingly
-
- :param image_request_item: JobItem = the image request that failed.
- :param err: Exception = the exception that caused the failure
-
- :return: None
- """
- logger.exception(f"Failed to start image processing!: {err}")
- self.image_status_monitor.process_event(image_request_item, RequestStatus.FAILED, str(err))
-
- def complete_image_request(
- self, region_request: RegionRequest, image_format: str, raster_dataset: gdal.Dataset, sensor_model: SensorModel
+ def queue_region_request(
+ self,
+ all_regions: List[ImageRegion],
+ image_request: ImageRequest,
+ raster_dataset: Dataset,
+ sensor_model: Optional[SensorModel],
+ image_extension: Optional[str],
) -> None:
"""
- Runs after every region has completed processing to check if that was the last region and run required
- completion logic for the associated ImageRequest.
+ Queues region requests and handles processing of the first region.
+ """
+ first_region = all_regions.pop(0)
+ for region in all_regions:
+ logger.debug(f"Queueing region: {region}")
- :param region_request: RegionRequest = the region request to update.
- :param image_format: Format of the image data
- :param raster_dataset: the image data rater
- :param sensor_model: the image sensor model
+ region_request = RegionRequest(
+ image_request.get_shared_values(),
+ region_bounds=region,
+ region_id=f"{region[0]}{region[1]}-{image_request.job_id}",
+ image_extension=image_extension,
+ )
- :return: None
- """
- try:
- # Grab the full image request item from the table
- image_request_item = self.job_table.get_image_request(region_request.image_id)
+ region_request_item = RegionRequestItem.from_region_request(region_request)
+ self.region_request_table.start_region_request(region_request_item)
- logger.debug("Last region of image request was completed, aggregating features for image!")
+ logger.debug(
+ f"Adding region request: image id: {region_request_item.image_id} - "
+ f"region id: {region_request_item.region_id}"
+ )
- roi = None
- if image_request_item.roi_wkt:
- logger.debug(f"Using ROI from request to set processing boundary: {image_request_item.roi_wkt}")
- roi = shapely.wkt.loads(image_request_item.roi_wkt)
- processing_bounds = calculate_processing_bounds(raster_dataset, roi, sensor_model)
- logger.debug(f"Processing boundary from {roi} is {processing_bounds}")
+ self.region_request_queue.send_request(region_request.__dict__)
- # Set up our feature table to work with the region quest
- feature_table = FeatureTable(self.config.feature_table, region_request.tile_size, region_request.tile_overlap)
- # Aggregate all the features from our job
- features = feature_table.aggregate_features(image_request_item)
- features = self.select_features(image_request_item, features, processing_bounds)
- features = self.add_properties_to_features(image_request_item, features)
+ logger.debug(f"Processing first region: {first_region}")
- # Sink the features into the right outputs
- is_write_succeeded = self.sink_features(image_request_item, features)
- if not is_write_succeeded:
- raise AggregateOutputFeaturesException(
- "Failed to write features to S3 or Kinesis! Please check the " "log..."
- )
+ first_region_request = RegionRequest(
+ image_request.get_shared_values(),
+ region_bounds=first_region,
+ region_id=f"{first_region[0]}{first_region[1]}-{image_request.job_id}",
+ image_extension=image_extension,
+ )
- # Put our end time on our image_request_item
- completed_image_request_item = self.job_table.end_image_request(image_request_item.image_id)
+ first_region_request_item = RegionRequestItem.from_region_request(first_region_request)
- # Ensure we have a valid start time for our record
- # TODO: Figure out why we wouldn't have a valid start time?!?!
- if completed_image_request_item.processing_duration is not None:
- image_request_status = self.image_status_monitor.get_status(completed_image_request_item)
- self.image_status_monitor.process_event(
- completed_image_request_item, image_request_status, "Completed image processing"
- )
- self.generate_image_processing_metrics(completed_image_request_item, image_format)
- else:
- raise InvalidImageRequestException("ImageRequest has no start time")
+ image_request_item = self.region_request_handler.process_region_request(
+ first_region_request, first_region_request_item, raster_dataset, sensor_model
+ )
- except Exception as err:
- raise AggregateFeaturesException("Failed to aggregate features for region!") from err
+ if self.job_table.is_image_request_complete(image_request_item):
+ image_format = str(raster_dataset.GetDriver().ShortName).upper()
+ self.complete_image_request(first_region_request, image_format, raster_dataset, sensor_model)
@metric_scope
def generate_image_processing_metrics(
@@ -680,49 +377,6 @@ def generate_image_processing_metrics(
if image_request_item.region_error > 0:
metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))
- def fail_region_request(
- self,
- region_request_item: RegionRequestItem,
- metrics: MetricsLogger = None,
- ) -> JobItem:
- """
- Fails a region if it failed to process successfully and updates the table accordingly before
- raising an exception
-
- :param region_request_item: RegionRequestItem = the region request to update
- :param metrics: MetricsLogger = the metrics logger to use to report metrics.
-
- :return: None
- """
- if isinstance(metrics, MetricsLogger):
- metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))
- try:
- region_status = RequestStatus.FAILED
- region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)
- self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")
- return self.job_table.complete_region_request(region_request_item.image_id, error=True)
- except Exception as status_error:
- logger.error("Unable to update region status in job table")
- logger.exception(status_error)
- raise ProcessRegionException("Failed to process image region!")
-
- def validate_model_hosting(self, image_request: JobItem):
- """
- Validates that the image request is valid. If not, raises an exception.
-
- :param image_request: JobItem = the image request
-
- :return: None
- """
- if not image_request.model_invoke_mode or image_request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS:
- error = f"Application only supports ${VALID_MODEL_HOSTING_OPTIONS} Endpoints"
- self.image_status_monitor.process_event(
- image_request,
- RequestStatus.FAILED,
- error,
- )
- raise UnsupportedModelException(error)
-
@metric_scope
def select_features(
self,
@@ -824,16 +478,16 @@ def sink_features(image_request_item: JobItem, features: List[Feature], metrics:
# Log them let them know if both written to both outputs (S3 and Kinesis) or one in another
# If both couldn't write to either stream because both were down, return False. Otherwise True
if tracking_output_sinks["S3"] and not tracking_output_sinks["Kinesis"]:
- logging.debug("OSMLModelRunner was able to write the features to S3 but not Kinesis. Continuing...")
+ logging.debug("ModelRunner was able to write the features to S3 but not Kinesis. Continuing...")
return True
elif not tracking_output_sinks["S3"] and tracking_output_sinks["Kinesis"]:
- logging.debug("OSMLModelRunner was able to write the features to Kinesis but not S3. Continuing...")
+ logging.debug("ModelRunner was able to write the features to Kinesis but not S3. Continuing...")
return True
elif tracking_output_sinks["S3"] and tracking_output_sinks["Kinesis"]:
- logging.debug("OSMLModelRunner was able to write the features to both S3 and Kinesis. Continuing...")
+ logging.debug("ModelRunner was able to write the features to both S3 and Kinesis. Continuing...")
return True
else:
- logging.error("OSMLModelRunner was not able to write the features to either S3 or Kinesis. Failing...")
+ logging.error("ModelRunner was not able to write the features to either S3 or Kinesis. Failing...")
return False
else:
raise InvalidImageRequestException("No output destinations were defined for this image request!")
@@ -898,6 +552,23 @@ def get_inference_metadata_property(image_request_item: JobItem, inference_time:
}
return inference_metadata_property
+ def validate_model_hosting(self, image_request: JobItem):
+     """
+     Validates that the image request uses a supported model hosting option. If it does not, a FAILED status
+     is reported to the image status monitor and an exception is raised.
+
+     :param image_request: JobItem = the image request
+
+     :return: None
+     :raises UnsupportedModelException: if model_invoke_mode is unset or not in VALID_MODEL_HOSTING_OPTIONS
+     """
+     if not image_request.model_invoke_mode or image_request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS:
+         # Plain f-string placeholder: a "$" prefix would be emitted literally in the message
+         error = f"Application only supports {VALID_MODEL_HOSTING_OPTIONS} Endpoints"
+         # Report the failure before raising so the status monitor sees the reason
+         self.image_status_monitor.process_event(image_request, RequestStatus.FAILED, error)
+         raise UnsupportedModelException(error)
+
@staticmethod
def get_extents(ds: gdal.Dataset, sm: SensorModel) -> Dict[str, Any]:
"""
diff --git a/src/aws/osml/model_runner/inference/http_detector.py b/src/aws/osml/model_runner/inference/http_detector.py
index 2bf46250..b922a9ae 100644
--- a/src/aws/osml/model_runner/inference/http_detector.py
+++ b/src/aws/osml/model_runner/inference/http_detector.py
@@ -16,8 +16,8 @@
from urllib3.util.retry import Retry
from aws.osml.model_runner.api import ModelInvokeMode
-from aws.osml.model_runner.app_config import MetricLabels
from aws.osml.model_runner.common import Timer
+from aws.osml.model_runner.config import MetricLabels
from .detector import Detector
from .endpoint_builder import FeatureEndpointBuilder
diff --git a/src/aws/osml/model_runner/inference/sm_detector.py b/src/aws/osml/model_runner/inference/sm_detector.py
index 093518b0..ca840123 100644
--- a/src/aws/osml/model_runner/inference/sm_detector.py
+++ b/src/aws/osml/model_runner/inference/sm_detector.py
@@ -14,8 +14,8 @@
from geojson import FeatureCollection
from aws.osml.model_runner.api import ModelInvokeMode
-from aws.osml.model_runner.app_config import BotoConfig, MetricLabels
from aws.osml.model_runner.common import Timer
+from aws.osml.model_runner.config import BotoConfig, MetricLabels
from .detector import Detector
from .endpoint_builder import FeatureEndpointBuilder
diff --git a/src/aws/osml/model_runner/model_runner.py b/src/aws/osml/model_runner/model_runner.py
new file mode 100755
index 00000000..7461c191
--- /dev/null
+++ b/src/aws/osml/model_runner/model_runner.py
@@ -0,0 +1,174 @@
+# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.
+
+import logging
+
+from osgeo import gdal
+
+from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration
+
+from .api import ImageRequest, InvalidImageRequestException, RegionRequest
+from .common import EndpointUtils, ThreadingLocalContextFilter
+from .config import ServiceConfig
+from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable
+from .exceptions import RetryableJobException, SelfThrottledRegionException
+from .image_request_handler import ImageRequestHandler
+from .queue import RequestQueue
+from .region_request_handler import RegionRequestHandler
+from .status import ImageStatusMonitor, RegionStatusMonitor
+from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy
+
+# Set up logging configuration
+logger = logging.getLogger(__name__)
+
+# GDAL 4.0 will begin using exceptions as the default. NOTE(review): the call below enables exceptions, yet the
+# original note said the software assumes no exceptions; confirm whether DontUseExceptions() was intended.
+gdal.UseExceptions()
+
+
+class ModelRunner:
+    """
+    Main class for operating the ModelRunner application. It monitors input queues for processing requests,
+    decomposes the image into a set of smaller regions and tiles, invokes an ML model endpoint with each tile, and
+    finally aggregates all the results into a single output which can be deposited into the desired output sinks.
+    """
+
+    def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None:
+        """
+        Initialize a model runner and wire its queues, tables and monitors into the region and image handlers.
+
+        :param tiling_strategy: strategy for breaking a larger image into chunks (the default instance is shared)
+        """
+        self.config = ServiceConfig()
+        self.tiling_strategy = tiling_strategy
+        self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0)
+        self.image_requests_iter = iter(self.image_request_queue)
+        self.job_table = JobTable(self.config.job_table)
+        self.region_request_table = RegionRequestTable(self.config.region_request_table)
+        self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table)
+        self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10)
+        self.region_requests_iter = iter(self.region_request_queue)
+        self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic)
+        self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic)
+        self.endpoint_utils = EndpointUtils()
+        self.running = False  # set by monitor_work_queues(), cleared by stop()
+
+        # Pass dependencies into RegionRequestHandler
+        self.region_request_handler = RegionRequestHandler(
+            region_request_table=self.region_request_table,
+            job_table=self.job_table,
+            region_status_monitor=self.region_status_monitor,
+            endpoint_statistics_table=self.endpoint_statistics_table,
+            tiling_strategy=self.tiling_strategy,
+            region_request_queue=self.region_request_queue,
+            endpoint_utils=self.endpoint_utils,
+            config=self.config,
+        )
+
+        # Pass dependencies into ImageRequestHandler
+        self.image_request_handler = ImageRequestHandler(
+            job_table=self.job_table,
+            image_status_monitor=self.image_status_monitor,
+            endpoint_statistics_table=self.endpoint_statistics_table,
+            tiling_strategy=self.tiling_strategy,
+            region_request_queue=self.region_request_queue,
+            region_request_table=self.region_request_table,
+            endpoint_utils=self.endpoint_utils,
+            config=self.config,
+            region_request_handler=self.region_request_handler,
+        )
+
+    def run(self) -> None:
+        """
+        Starts ModelRunner in a loop that continuously monitors the image work queue and region work queue.
+        """
+        self.monitor_work_queues()
+
+    def stop(self) -> None:
+        """
+        Stops ModelRunner by clearing the running flag that the work loop checks before each iteration.
+        """
+        self.running = False
+
+    def monitor_work_queues(self) -> None:
+        """
+        Polls the region request queue and, only when it yields nothing, the image request queue, passing each
+        message to the matching handler while self.running is set. Retryable and self-throttled requests are reset
+        on their queue (reset_request); every other outcome, success or failure, finishes the message.
+        """
+        self.running = True
+        set_gdal_default_configuration()
+
+        try:
+            while self.running:
+                logger.debug("Checking work queue for regions to process ...")
+                (receipt_handle, region_request_attributes) = next(self.region_requests_iter)
+                ThreadingLocalContextFilter.set_context(region_request_attributes)
+
+                # The image queue is only polled when the region queue returned nothing
+                if region_request_attributes is not None:
+                    try:
+                        region_request = RegionRequest(region_request_attributes)
+
+                        if "s3:/" in region_request.image_url:
+                            ImageRequest.validate_image_path(region_request.image_url, region_request.image_read_role)
+                            image_path = region_request.image_url.replace("s3:/", "/vsis3", 1)
+                        else:
+                            image_path = region_request.image_url
+
+                        raster_dataset, sensor_model = load_gdal_dataset(image_path)
+                        image_format = str(raster_dataset.GetDriver().ShortName).upper()
+
+                        region_request_item = self.region_request_table.get_region_request(
+                            region_request.region_id, region_request.image_id
+                        )
+                        if region_request_item is None:
+                            region_request_item = RegionRequestItem.from_region_request(region_request)
+
+                        image_request_item = self.region_request_handler.process_region_request(
+                            region_request, region_request_item, raster_dataset, sensor_model
+                        )
+
+                        if self.job_table.is_image_request_complete(image_request_item):
+                            self.image_request_handler.complete_image_request(
+                                region_request, image_format, raster_dataset, sensor_model
+                            )
+
+                        self.region_request_queue.finish_request(receipt_handle)
+                    except RetryableJobException:
+                        self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0)
+                    except SelfThrottledRegionException:
+                        self.region_request_queue.reset_request(
+                            receipt_handle, visibility_timeout=int(self.config.throttling_retry_timeout)
+                        )
+                    except Exception as err:
+                        logger.error(f"There was a problem processing the region request: {err}")
+                        self.region_request_queue.finish_request(receipt_handle)
+                else:
+                    logger.debug("Checking work queue for images to process ...")
+                    (receipt_handle, image_request_message) = next(self.image_requests_iter)
+
+                    if image_request_message is not None:
+                        image_request = None
+                        try:
+                            image_request = ImageRequest.from_external_message(image_request_message)
+                            ThreadingLocalContextFilter.set_context(image_request.__dict__)
+
+                            if not image_request.is_valid():
+                                error = f"Invalid image request: {image_request_message}"
+                                # No active exception here, so error() is used instead of exception()
+                                logger.error(error)
+                                raise InvalidImageRequestException(error)
+
+                            self.image_request_handler.process_image_request(image_request)
+                            self.image_request_queue.finish_request(receipt_handle)
+                        except RetryableJobException:
+                            self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0)
+                        except Exception as err:
+                            logger.error(f"There was a problem processing the image request: {err}")
+                            min_image_id = image_request.image_id if image_request and image_request.image_id else ""
+                            min_job_id = image_request.job_id if image_request and image_request.job_id else ""
+                            minimal_job_item = JobItem(image_id=min_image_id, job_id=min_job_id, processing_duration=0)
+                            self.image_request_handler.fail_image_request(minimal_job_item, err)
+                            self.image_request_queue.finish_request(receipt_handle)
+        finally:
+            self.running = False
diff --git a/src/aws/osml/model_runner/queue/request_queue.py b/src/aws/osml/model_runner/queue/request_queue.py
index 5e539db5..1478e5fe 100755
--- a/src/aws/osml/model_runner/queue/request_queue.py
+++ b/src/aws/osml/model_runner/queue/request_queue.py
@@ -7,7 +7,7 @@
import boto3
from botocore.exceptions import ClientError
-from aws.osml.model_runner.app_config import BotoConfig
+from aws.osml.model_runner.config import BotoConfig
class RequestQueue:
diff --git a/src/aws/osml/model_runner/region_request_handler.py b/src/aws/osml/model_runner/region_request_handler.py
new file mode 100644
index 00000000..0d67a7d5
--- /dev/null
+++ b/src/aws/osml/model_runner/region_request_handler.py
@@ -0,0 +1,207 @@
+# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.
+
+import logging
+from typing import Optional
+
+from aws_embedded_metrics.logger.metrics_logger import MetricsLogger
+from aws_embedded_metrics.metric_scope import metric_scope
+from aws_embedded_metrics.unit import Unit
+from osgeo import gdal
+
+from aws.osml.photogrammetry import SensorModel
+
+from .api import RegionRequest
+from .common import EndpointUtils, RequestStatus, Timer
+from .config import MetricLabels, ServiceConfig
+from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable
+from .exceptions import ProcessRegionException, SelfThrottledRegionException
+from .queue import RequestQueue
+from .status import RegionStatusMonitor
+from .tile_worker import TilingStrategy, process_tiles, setup_tile_workers
+
+# Set up logging configuration
+logger = logging.getLogger(__name__)
+
+# GDAL 4.0 will begin using exceptions as the default. NOTE(review): the call below enables exceptions, yet the
+# original note said the software assumes no exceptions; confirm whether DontUseExceptions() was intended.
+gdal.UseExceptions()
+
+
+class RegionRequestHandler:
+    """
+    Class responsible for handling RegionRequest processing.
+    """
+
+    def __init__(
+        self,
+        region_request_table: RegionRequestTable,
+        job_table: JobTable,
+        region_status_monitor: RegionStatusMonitor,
+        endpoint_statistics_table: EndpointStatisticsTable,
+        tiling_strategy: TilingStrategy,
+        region_request_queue: RequestQueue,
+        endpoint_utils: EndpointUtils,
+        config: ServiceConfig,
+    ) -> None:
+        """
+        Initialize the RegionRequestHandler with the necessary dependencies.
+
+        :param region_request_table: The table that handles region requests.
+        :param job_table: The job table for image/region processing.
+        :param region_status_monitor: A monitor to track region request status.
+        :param endpoint_statistics_table: Table for tracking endpoint statistics.
+        :param tiling_strategy: The strategy for handling image tiling.
+        :param region_request_queue: Queue to send region requests.
+        :param endpoint_utils: Utility class for handling endpoint-related operations.
+        :param config: Configuration settings for the service.
+        """
+        self.region_request_table = region_request_table
+        self.job_table = job_table
+        self.region_status_monitor = region_status_monitor
+        self.endpoint_statistics_table = endpoint_statistics_table
+        self.tiling_strategy = tiling_strategy
+        self.region_request_queue = region_request_queue
+        self.endpoint_utils = endpoint_utils
+        self.config = config
+
+    @metric_scope
+    def process_region_request(
+        self,
+        region_request: RegionRequest,
+        region_request_item: RegionRequestItem,
+        raster_dataset: gdal.Dataset,
+        sensor_model: Optional[SensorModel] = None,
+        metrics: MetricsLogger = None,
+    ) -> JobItem:
+        """
+        Processes a RegionRequest. When self-throttling is enabled the endpoint's in-progress region count is checked
+        first; the region is then marked as started, its tiles are run through the tile workers, and the tile counts
+        and completion status are written back to the region request table and the job table.
+
+        :param region_request: RegionRequest = the region request
+        :param region_request_item: RegionRequestItem = the region request to update
+        :param raster_dataset: gdal.Dataset = the raster dataset containing the region
+        :param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset
+        :param metrics: MetricsLogger = the metrics logger to use to report metrics.
+
+        :return: JobItem = job table result for the image (from fail_region_request when processing failed)
+        :raises ValueError: if the region request is not valid
+        :raises SelfThrottledRegionException: if self-throttling is enabled and the endpoint is at its region limit
+        """
+        if isinstance(metrics, MetricsLogger):
+            metrics.set_dimensions()
+
+        if not region_request.is_valid():
+            logger.error(f"Invalid Region Request! {region_request.__dict__}")
+            raise ValueError("Invalid Region Request")
+
+        if isinstance(metrics, MetricsLogger):
+            image_format = str(raster_dataset.GetDriver().ShortName).upper()
+            metrics.put_dimensions(
+                {
+                    MetricLabels.OPERATION_DIMENSION: MetricLabels.REGION_PROCESSING_OPERATION,
+                    MetricLabels.MODEL_NAME_DIMENSION: region_request.model_name,
+                    MetricLabels.INPUT_FORMAT_DIMENSION: image_format,
+                }
+            )
+
+        if self.config.self_throttling:
+            max_regions = self.endpoint_utils.calculate_max_regions(
+                region_request.model_name, region_request.model_invocation_role
+            )
+            # Add entry to the endpoint statistics table
+            self.endpoint_statistics_table.upsert_endpoint(region_request.model_name, max_regions)
+            in_progress = self.endpoint_statistics_table.current_in_progress_regions(region_request.model_name)
+
+            if in_progress >= max_regions:
+                if isinstance(metrics, MetricsLogger):
+                    metrics.put_metric(MetricLabels.THROTTLES, 1, str(Unit.COUNT.value))
+                logger.warning(f"Throttling region request. (Max: {max_regions} In-progress: {in_progress})")
+                raise SelfThrottledRegionException
+
+            # Increment the endpoint region counter
+            self.endpoint_statistics_table.increment_region_count(region_request.model_name)
+
+        try:
+            with Timer(
+                task_str=f"Processing region {region_request.image_url} {region_request.region_bounds}",
+                metric_name=MetricLabels.DURATION,
+                logger=logger,
+                metrics_logger=metrics,
+            ):
+                self.region_request_table.start_region_request(region_request_item)
+                logger.debug(f"Starting region request: region id: {region_request_item.region_id}")
+
+                # Set up our threaded tile worker pool
+                tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.config.elevation_model)
+
+                # Process all our tiles
+                total_tile_count, failed_tile_count = process_tiles(
+                    self.tiling_strategy,
+                    region_request_item,
+                    tile_queue,
+                    tile_workers,
+                    raster_dataset,
+                    sensor_model,
+                )
+
+            # Update table w/ total tile counts
+            region_request_item.total_tiles = total_tile_count
+            region_request_item.succeeded_tile_count = total_tile_count - failed_tile_count
+            region_request_item.failed_tile_count = failed_tile_count
+            region_request_item = self.region_request_table.update_region_request(region_request_item)
+
+            # Update the image request to complete this region
+            image_request_item = self.job_table.complete_region_request(region_request.image_id, bool(failed_tile_count))
+
+            # Update region request table if that region succeeded
+            region_status = self.region_status_monitor.get_status(region_request_item)
+            region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)
+
+            self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")
+
+            # Write CloudWatch Metrics to the Logs
+            if isinstance(metrics, MetricsLogger):
+                metrics.put_metric(MetricLabels.INVOCATIONS, 1, str(Unit.COUNT.value))
+
+            # Return the updated item
+            return image_request_item
+
+        except Exception as err:
+            failed_msg = f"Failed to process image region: {err}"
+            logger.error(failed_msg)
+            # Update the table to record the failure
+            region_request_item.message = failed_msg
+            return self.fail_region_request(region_request_item)
+
+        finally:
+            # Decrement the endpoint region counter
+            if self.config.self_throttling:
+                self.endpoint_statistics_table.decrement_region_count(region_request.model_name)
+
+    @metric_scope
+    def fail_region_request(
+        self,
+        region_request_item: RegionRequestItem,
+        metrics: MetricsLogger = None,
+    ) -> JobItem:
+        """
+        Marks the region request as FAILED, reports the status event and records the region as an error on its image
+        in the job table. Raises ProcessRegionException if any of those updates fail.
+
+        :param region_request_item: RegionRequestItem = the region request to update
+        :param metrics: MetricsLogger = the metrics logger to use to report metrics.
+
+        :return: JobItem = the result of JobTable.complete_region_request for the region's image
+        """
+        if isinstance(metrics, MetricsLogger):
+            metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value))
+        try:
+            region_status = RequestStatus.FAILED
+            region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status)
+            self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing")
+            return self.job_table.complete_region_request(region_request_item.image_id, error=True)
+        except Exception as status_error:
+            logger.error("Unable to update region status in job table")
+            logger.exception(status_error)
+            raise ProcessRegionException("Failed to process image region!") from status_error
diff --git a/src/aws/osml/model_runner/sink/kinesis_sink.py b/src/aws/osml/model_runner/sink/kinesis_sink.py
index 5f4e3732..f6739490 100755
--- a/src/aws/osml/model_runner/sink/kinesis_sink.py
+++ b/src/aws/osml/model_runner/sink/kinesis_sink.py
@@ -9,8 +9,8 @@
from geojson import Feature, FeatureCollection
from aws.osml.model_runner.api import SinkMode, SinkType
-from aws.osml.model_runner.app_config import BotoConfig, ServiceConfig
from aws.osml.model_runner.common import get_credentials_for_assumed_role
+from aws.osml.model_runner.config import BotoConfig, ServiceConfig
from .exceptions import InvalidKinesisStreamException
from .sink import Sink
diff --git a/src/aws/osml/model_runner/sink/s3_sink.py b/src/aws/osml/model_runner/sink/s3_sink.py
index b7e29316..4fd273af 100755
--- a/src/aws/osml/model_runner/sink/s3_sink.py
+++ b/src/aws/osml/model_runner/sink/s3_sink.py
@@ -12,8 +12,8 @@
from geojson import Feature, FeatureCollection
from aws.osml.model_runner.api import SinkMode, SinkType
-from aws.osml.model_runner.app_config import BotoConfig
from aws.osml.model_runner.common import get_credentials_for_assumed_role
+from aws.osml.model_runner.config import BotoConfig
from .sink import Sink
diff --git a/src/aws/osml/model_runner/status/sns_helper.py b/src/aws/osml/model_runner/status/sns_helper.py
index 957849f1..1f155f59 100755
--- a/src/aws/osml/model_runner/status/sns_helper.py
+++ b/src/aws/osml/model_runner/status/sns_helper.py
@@ -5,7 +5,7 @@
import boto3
-from aws.osml.model_runner.app_config import BotoConfig
+from aws.osml.model_runner.config import BotoConfig
from .exceptions import SNSPublishException
diff --git a/src/aws/osml/model_runner/tile_worker/tile_worker.py b/src/aws/osml/model_runner/tile_worker/tile_worker.py
index 7b6badfa..2fbc8fec 100755
--- a/src/aws/osml/model_runner/tile_worker/tile_worker.py
+++ b/src/aws/osml/model_runner/tile_worker/tile_worker.py
@@ -14,8 +14,8 @@
from shapely.affinity import translate
from aws.osml.features import Geolocator, ImagedFeaturePropertyAccessor
-from aws.osml.model_runner.app_config import MetricLabels
from aws.osml.model_runner.common import ThreadingLocalContextFilter, TileState, Timer
+from aws.osml.model_runner.config import MetricLabels
from aws.osml.model_runner.database import FeatureTable, RegionRequestTable
from aws.osml.model_runner.inference import Detector
diff --git a/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py b/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py
index 750441d0..ff12aebd 100755
--- a/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py
+++ b/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py
@@ -19,8 +19,8 @@
from aws.osml.photogrammetry import ElevationModel, SensorModel
from ..api import RegionRequest
-from ..app_config import MetricLabels, ServiceConfig
from ..common import Timer, get_credentials_for_assumed_role
+from ..config import MetricLabels, ServiceConfig
from ..database import FeatureTable
from ..inference.endpoint_factory import FeatureDetectorFactory
from .exceptions import ProcessTilesException, SetupTileWorkersException
diff --git a/test/test_app.py b/test/test_app.py
index 6f7bfd0d..c01dcfa5 100755
--- a/test/test_app.py
+++ b/test/test_app.py
@@ -452,11 +452,11 @@ def test_create_elevation_model(self) -> None:
The import and reload statements are necessary to force the ServiceConfig to update with the
patched environment variables.
"""
- import aws.osml.model_runner.app_config
+ import aws.osml.model_runner.config
- reload(aws.osml.model_runner.app_config)
+ reload(aws.osml.model_runner.config)
from aws.osml.gdal.gdal_dem_tile_factory import GDALDigitalElevationModelTileFactory
- from aws.osml.model_runner.app_config import ServiceConfig
+ from aws.osml.model_runner.config import ServiceConfig
from aws.osml.photogrammetry.digital_elevation_model import DigitalElevationModel
from aws.osml.photogrammetry.srtm_dem_tile_set import SRTMTileSet