diff --git a/bin/oversightml-mr-entry-point.py b/bin/oversightml-mr-entry-point.py index 3f924c7f..ad58fff9 100644 --- a/bin/oversightml-mr-entry-point.py +++ b/bin/oversightml-mr-entry-point.py @@ -11,8 +11,8 @@ from codeguru_profiler_agent import Profiler from pythonjsonlogger import jsonlogger -from aws.osml.model_runner.app import ModelRunner from aws.osml.model_runner.common import ThreadingLocalContextFilter +from aws.osml.model_runner.model_runner import ModelRunner def handler_stop_signals(signal_num: int, frame: Optional[FrameType], model_runner: ModelRunner) -> None: diff --git a/src/aws/osml/model_runner/api/image_request.py b/src/aws/osml/model_runner/api/image_request.py index e1a7a2d2..6a1d4b87 100755 --- a/src/aws/osml/model_runner/api/image_request.py +++ b/src/aws/osml/model_runner/api/image_request.py @@ -9,7 +9,6 @@ import shapely.wkt from shapely.geometry.base import BaseGeometry -from aws.osml.model_runner.app_config import BotoConfig from aws.osml.model_runner.common import ( FeatureDistillationAlgorithm, FeatureDistillationNMS, @@ -21,6 +20,7 @@ deserialize_post_processing_list, get_credentials_for_assumed_role, ) +from aws.osml.model_runner.config import BotoConfig from .exceptions import InvalidS3ObjectException from .inference import ModelInvokeMode diff --git a/src/aws/osml/model_runner/common/credentials_utils.py b/src/aws/osml/model_runner/common/credentials_utils.py index f7e09c07..95eb0050 100755 --- a/src/aws/osml/model_runner/common/credentials_utils.py +++ b/src/aws/osml/model_runner/common/credentials_utils.py @@ -4,7 +4,7 @@ import boto3 -from aws.osml.model_runner.app_config import BotoConfig +from aws.osml.model_runner.config import BotoConfig from .exceptions import InvalidAssumedRoleException diff --git a/src/aws/osml/model_runner/common/endpoint_utils.py b/src/aws/osml/model_runner/common/endpoint_utils.py index 40cac830..b02a159e 100755 --- a/src/aws/osml/model_runner/common/endpoint_utils.py +++ b/src/aws/osml/model_runner/common/endpoint_utils.py @@ -8,7 +8,7 @@ import boto3 from cachetools import TTLCache, cachedmethod -from aws.osml.model_runner.app_config import BotoConfig, ServiceConfig +from aws.osml.model_runner.config import BotoConfig, ServiceConfig from .credentials_utils import get_credentials_for_assumed_role diff --git a/src/aws/osml/model_runner/app_config.py b/src/aws/osml/model_runner/config.py similarity index 100% rename from src/aws/osml/model_runner/app_config.py rename to src/aws/osml/model_runner/config.py diff --git a/src/aws/osml/model_runner/database/ddb_helper.py b/src/aws/osml/model_runner/database/ddb_helper.py index 34835260..670334ff 100755 --- a/src/aws/osml/model_runner/database/ddb_helper.py +++ b/src/aws/osml/model_runner/database/ddb_helper.py @@ -10,7 +10,7 @@ import boto3 from boto3.dynamodb.conditions import Key -from aws.osml.model_runner.app_config import BotoConfig +from aws.osml.model_runner.config import BotoConfig from .exceptions import DDBBatchWriteException, DDBUpdateException diff --git a/src/aws/osml/model_runner/database/feature_table.py b/src/aws/osml/model_runner/database/feature_table.py index 0c88dea3..1638592b 100755 --- a/src/aws/osml/model_runner/database/feature_table.py +++ b/src/aws/osml/model_runner/database/feature_table.py @@ -16,8 +16,8 @@ from dacite import from_dict from geojson import Feature -from aws.osml.model_runner.app_config import MetricLabels, ServiceConfig from aws.osml.model_runner.common import ImageDimensions, Timer, get_feature_image_bounds +from aws.osml.model_runner.config import MetricLabels, ServiceConfig from .ddb_helper import DDBHelper, DDBItem, DDBKey from .exceptions import AddFeaturesException diff --git a/src/aws/osml/model_runner/app.py b/src/aws/osml/model_runner/image_request_handler.py old mode 100755 new mode 100644 similarity index 55% rename from src/aws/osml/model_runner/app.py rename to src/aws/osml/model_runner/image_request_handler.py index a7da9f6d..d14172d6 --- a/src/aws/osml/model_runner/app.py +++ b/src/aws/osml/model_runner/image_request_handler.py @@ -11,18 +11,16 @@ from typing import Any, Dict, List, Optional, Tuple import shapely.geometry.base -from aws_embedded_metrics.logger.metrics_logger import MetricsLogger -from aws_embedded_metrics.metric_scope import metric_scope +from aws_embedded_metrics import MetricsLogger, metric_scope from aws_embedded_metrics.unit import Unit from geojson import Feature from osgeo import gdal from osgeo.gdal import Dataset -from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset, set_gdal_default_configuration +from aws.osml.gdal import GDALConfigEnv, get_image_extension, load_gdal_dataset from aws.osml.photogrammetry import ImageCoordinate, SensorModel from .api import VALID_MODEL_HOSTING_OPTIONS, ImageRequest, InvalidImageRequestException, RegionRequest, SinkMode -from .app_config import MetricLabels, ServiceConfig from .common import ( EndpointUtils, FeatureDistillationDeserializer, @@ -30,11 +28,11 @@ ImageDimensions, ImageRegion, RequestStatus, - ThreadingLocalContextFilter, Timer, get_credentials_for_assumed_role, mr_post_processing_options_factory, ) +from .config import MetricLabels, ServiceConfig from .database import EndpointStatisticsTable, FeatureTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable from .exceptions import ( AggregateFeaturesException, @@ -43,16 +41,14 @@ InvalidImageURLException, LoadImageException, ProcessImageException, - ProcessRegionException, - RetryableJobException, - SelfThrottledRegionException, UnsupportedModelException, ) from .inference import FeatureSelector, calculate_processing_bounds, get_source_property from .queue import RequestQueue +from .region_request_handler import RegionRequestHandler from .sink import SinkFactory -from .status import ImageStatusMonitor, RegionStatusMonitor -from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy, process_tiles, setup_tile_workers +from .status import ImageStatusMonitor +from .tile_worker import TilingStrategy # Set up logging configuration logger = logging.getLogger(__name__) @@ -62,177 +58,49 @@ gdal.UseExceptions() -class ModelRunner: +class ImageRequestHandler: """ - Main class for operating the ModelRunner application. It monitors input queues for processing requests, - decomposes the image into a set of smaller regions and tiles, invokes an ML model endpoint with each tile, and - finally aggregates all the results into a single output which can be deposited into the desired output sinks. + Class responsible for handling ImageRequest processing. """ - def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None: - """ - Initialize a model runner with the injectable behaviors. - - :param tiling_strategy: class defining how a larger image will be broken into chunks for processing - """ - self.config = ServiceConfig() - self.tiling_strategy = tiling_strategy - self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0) - self.image_requests_iter = iter(self.image_request_queue) - self.job_table = JobTable(self.config.job_table) - self.region_request_table = RegionRequestTable(self.config.region_request_table) - self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table) - self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10) - self.region_requests_iter = iter(self.region_request_queue) - self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic) - self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic) - self.endpoint_utils = EndpointUtils() - self.running = False - - def run(self) -> None: - """ - Starts ModelRunner in a loop that continuously monitors the image work queue and region work queue. - - :return: None - """ - self.monitor_work_queues() - - def stop(self) -> None: - """ - Stops ModelRunner by setting the global run variable to False. - - :return: None + def __init__( + self, + job_table: JobTable, + image_status_monitor: ImageStatusMonitor, + endpoint_statistics_table: EndpointStatisticsTable, + tiling_strategy: TilingStrategy, + region_request_queue: RequestQueue, + region_request_table: RegionRequestTable, + endpoint_utils: EndpointUtils, + config: ServiceConfig, + region_request_handler: RegionRequestHandler, + ) -> None: """ - self.running = False + Initialize the ImageRequestHandler with the necessary dependencies. - def monitor_work_queues(self) -> None: + :param job_table: The job table for image processing. + :param image_status_monitor: A monitor to track image request status. + :param endpoint_statistics_table: Table for tracking endpoint statistics. + :param tiling_strategy: The strategy for handling image tiling. + :param region_request_queue: Queue to send region requests. + :param region_request_table: Table to track region request progress. + :param endpoint_utils: Utility class for handling endpoint-related operations. + :param config: Configuration settings for the service. """ - Monitors SQS queues for ImageRequest and RegionRequest The region work queue is checked first and will wait for - up to 10 seconds to start work. Only if no regions need to be processed in that time will this worker check to - see if a new image can be started. Ultimately this setup is intended to ensure that all the regions for an image - are completed by the cluster before work begins on more images. - :return: None - """ - # Set the running state to True - self.running = True - - # Set up the GDAL configuration options that should remain unchanged for the life of this execution - set_gdal_default_configuration() - try: - while self.running: - logger.debug("Checking work queue for regions to process ...") - (receipt_handle, region_request_attributes) = next(self.region_requests_iter) - ThreadingLocalContextFilter.set_context(region_request_attributes) - - # If we found a region request on the queue - if region_request_attributes is not None: - try: - # Parse the message into a working RegionRequest - region_request = RegionRequest(region_request_attributes) - - # If the image request has a s3 url lets augment its path for virtual hosting - if "s3:/" in region_request.image_url: - # Validate that image exists in S3 - ImageRequest.validate_image_path(region_request.image_url, region_request.image_read_role) - image_path = region_request.image_url.replace("s3:/", "/vsis3", 1) - else: - image_path = region_request.image_url - - # Load the image into a GDAL dataset - raster_dataset, sensor_model = load_gdal_dataset(image_path) - image_format = str(raster_dataset.GetDriver().ShortName).upper() - - # Get RegionRequestItem if not create new RegionRequestItem - region_request_item = self.region_request_table.get_region_request( - region_request.region_id, region_request.image_id - ) - if region_request_item is None: - # Create a new item from the region request - region_request_item = RegionRequestItem.from_region_request(region_request) - - # Add the item to the table and start it processing - self.region_request_table.start_region_request(region_request_item) - logging.debug( - ( - f"Adding region request: image id: {region_request_item.image_id} - " - f"region id: {region_request_item.region_id}" - ) - ) - - # Process our region request - image_request_item = self.process_region_request( - region_request, region_request_item, raster_dataset, sensor_model - ) - - # Check if the image is complete - if self.job_table.is_image_request_complete(image_request_item): - # If so complete the image request - self.complete_image_request(region_request, image_format, raster_dataset, sensor_model) - - # Update the queue - self.region_request_queue.finish_request(receipt_handle) - except RetryableJobException: - self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0) - except SelfThrottledRegionException: - self.region_request_queue.reset_request( - receipt_handle, - visibility_timeout=int(self.config.throttling_retry_timeout), - ) - except Exception as err: - logger.error(f"There was a problem processing the region request: {err}") - self.region_request_queue.finish_request(receipt_handle) - else: - logger.debug("Checking work queue for images to process ...") - (receipt_handle, image_request_message) = next(self.image_requests_iter) - - # If we found a request on the queue - if image_request_message is not None: - image_request = None - try: - # Parse the message into a working ImageRequest - image_request = ImageRequest.from_external_message(image_request_message) - ThreadingLocalContextFilter.set_context(image_request.__dict__) - - # Check that our image request looks good - if not image_request.is_valid(): - error = f"Invalid image request: {image_request_message}" - logger.exception(error) - raise InvalidImageRequestException(error) - - # Process the request - self.process_image_request(image_request) - - # Update the queue - self.image_request_queue.finish_request(receipt_handle) - except RetryableJobException: - self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0) - except Exception as err: - logger.error(f"There was a problem processing the image request: {err}") - min_image_id = image_request.image_id if image_request and image_request.image_id else "" - min_job_id = image_request.job_id if image_request and image_request.job_id else "" - minimal_job_item = JobItem( - image_id=min_image_id, - job_id=min_job_id, - processing_duration=0, - ) - self.fail_image_request_send_messages(minimal_job_item, err) - self.image_request_queue.finish_request(receipt_handle) - finally: - # If we stop monitoring the queue set run state to false - self.running = False + self.job_table = job_table + self.image_status_monitor = image_status_monitor + self.endpoint_statistics_table = endpoint_statistics_table + self.tiling_strategy = tiling_strategy + self.region_request_queue = region_request_queue + self.region_request_table = region_request_table + self.endpoint_utils = endpoint_utils + self.config = config + self.region_request_handler = region_request_handler def process_image_request(self, image_request: ImageRequest) -> None: """ - Processes ImageRequest objects that are picked up from queue. Loads the specified image into memory to be - chipped apart into regions and sent downstream for processing via RegionRequest. This will also process the - first region chipped from the image. # This worker will process the first region of this image since it has - already loaded the dataset from S3 and is ready to go. Any additional regions will be queued for processing by - other workers in this cluster. - - :param image_request: ImageRequest = the image request derived from the ImageRequest SQS message - - :return: None + Processes ImageRequest objects picked up from the queue. """ image_request_item = None try: @@ -243,7 +111,6 @@ def process_image_request(self, image_request: ImageRequest) -> None: # Add entry to the endpoint statistics table self.endpoint_statistics_table.upsert_endpoint(image_request.model_name, max_regions) - # Update the image status to started and include relevant image meta-data logger.debug(f"Starting processing of {image_request.image_url}") image_request_item = JobItem( image_id=image_request.image_id, @@ -264,14 +131,11 @@ def process_image_request(self, image_request: ImageRequest) -> None: asdict(feature_distillation_option_list[0], dict_factory=mr_post_processing_options_factory) ) - # Start the image processing self.job_table.start_image_request(image_request_item) self.image_status_monitor.process_event(image_request_item, RequestStatus.STARTED, "Started image request") - # Check we have a valid image request, throws if not self.validate_model_hosting(image_request_item) - # Load the relevant image meta data into memory image_extension, raster_dataset, sensor_model, all_regions = self.load_image_request( image_request_item, image_request.roi ) @@ -281,38 +145,28 @@ def process_image_request(self, image_request: ImageRequest) -> None: f"Dataset {image_request_item.image_id} has no geo transform. Results are not geo-referenced." ) - # If we got valid outputs if raster_dataset and all_regions and image_extension: image_request_item.region_count = len(all_regions) image_request_item.width = int(raster_dataset.RasterXSize) image_request_item.height = int(raster_dataset.RasterYSize) try: - image_request_item.extents = json.dumps(ModelRunner.get_extents(raster_dataset, sensor_model)) + image_request_item.extents = json.dumps(self.get_extents(raster_dataset, sensor_model)) except Exception as e: logger.warning(f"Could not get extents for image: {image_request_item.image_id}") logger.exception(e) feature_properties: List[dict] = json.loads(image_request_item.feature_properties) - - # If we can get a valid source metadata from the source image - attach it to features - # else, just pass in whatever custom features if they were provided source_metadata = get_source_property(image_request_item.image_url, image_extension, raster_dataset) if isinstance(source_metadata, dict): feature_properties.append(source_metadata) - # Update the feature properties image_request_item.feature_properties = json.dumps(feature_properties) - - # Update the image request job to have new derived image data self.job_table.update_image_request(image_request_item) - self.image_status_monitor.process_event(image_request_item, RequestStatus.IN_PROGRESS, "Processing regions") - # Place the resulting region requests on the appropriate work queue self.queue_region_request(all_regions, image_request, raster_dataset, sensor_model, image_extension) except Exception as err: - # We failed try and gracefully update our image request if image_request_item: self.fail_image_request(image_request_item, err) else: @@ -322,186 +176,49 @@ def process_image_request(self, image_request: ImageRequest) -> None: processing_duration=0, ) self.fail_image_request(minimal_job_item, err) - - # Let the application know that we failed to process image raise ProcessImageException("Failed to process image region!") from err - def queue_region_request( - self, - all_regions: List[ImageRegion], - image_request: ImageRequest, - raster_dataset: Dataset, - sensor_model: Optional[SensorModel], - image_extension: Optional[str], + def complete_image_request( + self, region_request: RegionRequest, image_format: str, raster_dataset: gdal.Dataset, sensor_model: SensorModel ) -> None: """ - Loads the list of regions into the queue. First it will create a RequestRequestItem and creates - an entry into the RegionRequestTable for traceability. Then process the region request. Once it's completed, - it will update an entry in the RegionRequestTable. - - :param image_extension: = the GDAL derived image extension - :param all_regions: List[ImageRegion] = the list of image regions - :param image_request: ImageRequest = the image request - :param raster_dataset: Dataset = the raster dataset containing the region - :param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset - - :return: None - """ - # Set aside the first region - first_region = all_regions.pop(0) - for region in all_regions: - logger.debug(f"Queueing region: {region}") - - region_request = RegionRequest( - image_request.get_shared_values(), - region_bounds=region, - region_id=f"{region[0]}{region[1]}-{image_request.job_id}", - image_extension=image_extension, - ) - - # Create a new entry to the region request being started - region_request_item = RegionRequestItem.from_region_request(region_request) - self.region_request_table.start_region_request(region_request_item) - logging.debug( - ( - f"Adding region request: image id: {region_request_item.image_id} - " - f"region id: {region_request_item.region_id}" - ) - ) - - # Send the attributes of the region request as the message. - self.region_request_queue.send_request(region_request.__dict__) - - # Go ahead and process the first region - logger.debug(f"Processing first region {0}: {first_region}") - - first_region_request = RegionRequest( - image_request.get_shared_values(), - region_bounds=first_region, - region_id=f"{first_region[0]}{first_region[1]}-{image_request.job_id}", - image_extension=image_extension, - ) - - # Add item to RegionRequestTable - first_region_request_item = RegionRequestItem.from_region_request(first_region_request) - self.region_request_table.start_region_request(first_region_request_item) - logging.debug(f"Adding region_id: {first_region_request_item.region_id}") - - # Processes our region request and return the updated item - image_request_item = self.process_region_request( - first_region_request, first_region_request_item, raster_dataset, sensor_model - ) - - # If the image is finished then complete it - if self.job_table.is_image_request_complete(image_request_item): - image_format = str(raster_dataset.GetDriver().ShortName).upper() - self.complete_image_request(first_region_request, image_format, raster_dataset, sensor_model) - - @metric_scope - def process_region_request( - self, - region_request: RegionRequest, - region_request_item: RegionRequestItem, - raster_dataset: gdal.Dataset, - sensor_model: Optional[SensorModel] = None, - metrics: MetricsLogger = None, - ) -> JobItem: - """ - Processes RegionRequest objects that are delegated for processing. Loads the specified region of an image into - memory to be processed by tile-workers. If a raster_dataset is not provided directly it will poll the image - from the region request. - - :param region_request: RegionRequest = the region request - :param region_request_item: RegionRequestItem = the region request to update - :param raster_dataset: gdal.Dataset = the raster dataset containing the region - :param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset - :param metrics: MetricsLogger = the metrics logger to use to report metrics. - - :return: None + Runs after every region has completed processing to check if that was the last region and run required + completion logic for the associated ImageRequest. """ - if isinstance(metrics, MetricsLogger): - metrics.set_dimensions() - - if not region_request.is_valid(): - logger.error(f"Invalid Region Request! {region_request.__dict__}") - raise ValueError("Invalid Region Request") - - if isinstance(metrics, MetricsLogger): - image_format = str(raster_dataset.GetDriver().ShortName).upper() - metrics.put_dimensions( - { - MetricLabels.OPERATION_DIMENSION: MetricLabels.REGION_PROCESSING_OPERATION, - MetricLabels.MODEL_NAME_DIMENSION: region_request.model_name, - MetricLabels.INPUT_FORMAT_DIMENSION: image_format, - } - ) - - if self.config.self_throttling: - max_regions = self.endpoint_utils.calculate_max_regions( - region_request.model_name, region_request.model_invocation_role - ) - # Add entry to the endpoint statistics table - self.endpoint_statistics_table.upsert_endpoint(region_request.model_name, max_regions) - in_progress = self.endpoint_statistics_table.current_in_progress_regions(region_request.model_name) - - if in_progress >= max_regions: - if isinstance(metrics, MetricsLogger): - metrics.put_metric(MetricLabels.THROTTLES, 1, str(Unit.COUNT.value)) - logger.warning(f"Throttling region request. (Max: {max_regions} In-progress: {in_progress}") - raise SelfThrottledRegionException - - # Increment the endpoint region counter - self.endpoint_statistics_table.increment_region_count(region_request.model_name) - try: - with Timer( - task_str=f"Processing region {region_request.image_url} {region_request.region_bounds}", - metric_name=MetricLabels.DURATION, - logger=logger, - metrics_logger=metrics, - ): - # Set up our threaded tile worker pool - tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.config.elevation_model) - - # Process all our tiles - total_tile_count, failed_tile_count = process_tiles( - self.tiling_strategy, region_request_item, tile_queue, tile_workers, raster_dataset, sensor_model - ) + image_request_item = self.job_table.get_image_request(region_request.image_id) - # Update table w/ total tile counts - region_request_item.total_tiles = total_tile_count - region_request_item.succeeded_tile_count = total_tile_count - failed_tile_count - region_request_item.failed_tile_count = failed_tile_count - region_request_item = self.region_request_table.update_region_request(region_request_item) + logger.debug("Last region of image request was completed, aggregating features for image!") - # Update the image request to complete this region - image_request_item = self.job_table.complete_region_request(region_request.image_id, bool(failed_tile_count)) + roi = None + if image_request_item.roi_wkt: + logger.debug(f"Using ROI from request to set processing boundary: {image_request_item.roi_wkt}") + roi = shapely.wkt.loads(image_request_item.roi_wkt) + processing_bounds = calculate_processing_bounds(raster_dataset, roi, sensor_model) + logger.debug(f"Processing boundary from {roi} is {processing_bounds}") - # Update region request table if that region succeeded - region_status = self.region_status_monitor.get_status(region_request_item) - region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status) + feature_table = FeatureTable(self.config.feature_table, region_request.tile_size, region_request.tile_overlap) + features = feature_table.aggregate_features(image_request_item) + features = self.select_features(image_request_item, features, processing_bounds) + features = self.add_properties_to_features(image_request_item, features) - self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing") + is_write_succeeded = self.sink_features(image_request_item, features) + if not is_write_succeeded: + raise AggregateOutputFeaturesException("Failed to write features to S3 or Kinesis! Please check the log...") - # Write CloudWatch Metrics to the Logs - if isinstance(metrics, MetricsLogger): - # TODO: Consider adding the +1 invocation to timer - metrics.put_metric(MetricLabels.INVOCATIONS, 1, str(Unit.COUNT.value)) + completed_image_request_item = self.job_table.end_image_request(image_request_item.image_id) - # Return the updated item - return image_request_item + if completed_image_request_item.processing_duration is not None: + image_request_status = self.image_status_monitor.get_status(completed_image_request_item) + self.image_status_monitor.process_event( + completed_image_request_item, image_request_status, "Completed image processing" + ) + self.generate_image_processing_metrics(completed_image_request_item, image_format) + else: + raise InvalidImageRequestException("ImageRequest has no start time") except Exception as err: - failed_msg = f"Failed to process image region: {err}" - logger.error(failed_msg) - # update the table to take in that exception - region_request_item.message = failed_msg - return self.fail_region_request(region_request_item) - - finally: - # Decrement the endpoint region counter - if self.config.self_throttling: - self.endpoint_statistics_table.decrement_region_count(region_request.model_name) + raise AggregateFeaturesException("Failed to aggregate features for region!") from err def load_image_request( self, @@ -570,85 +287,65 @@ def load_image_request( def fail_image_request(self, image_request_item: JobItem, err: Exception) -> None: """ - Handles failure events/exceptions for image requests and tries to update the status monitor accordingly + Handles failure events/exceptions for image requests and tries to update the status monitor accordingly. :param image_request_item: JobItem = the image request that failed. :param err: Exception = the exception that caused the failure - :return: None """ - self.fail_image_request_send_messages(image_request_item, err) + self.fail_image_request(image_request_item, err) self.job_table.end_image_request(image_request_item.image_id) - def fail_image_request_send_messages(self, image_request_item: JobItem, err: Exception) -> None: - """ - Updates failed metrics and update the status monitor accordingly - - :param image_request_item: JobItem = the image request that failed. - :param err: Exception = the exception that caused the failure - - :return: None - """ - logger.exception(f"Failed to start image processing!: {err}") - self.image_status_monitor.process_event(image_request_item, RequestStatus.FAILED, str(err)) - - def complete_image_request( - self, region_request: RegionRequest, image_format: str, raster_dataset: gdal.Dataset, sensor_model: SensorModel + def queue_region_request( + self, + all_regions: List[ImageRegion], + image_request: ImageRequest, + raster_dataset: Dataset, + sensor_model: Optional[SensorModel], + image_extension: Optional[str], ) -> None: """ - Runs after every region has completed processing to check if that was the last region and run required - completion logic for the associated ImageRequest. + Queues region requests and handles processing of the first region. + """ + first_region = all_regions.pop(0) + for region in all_regions: + logger.debug(f"Queueing region: {region}") - :param region_request: RegionRequest = the region request to update. - :param image_format: Format of the image data - :param raster_dataset: the image data rater - :param sensor_model: the image sensor model + region_request = RegionRequest( + image_request.get_shared_values(), + region_bounds=region, + region_id=f"{region[0]}{region[1]}-{image_request.job_id}", + image_extension=image_extension, + ) - :return: None - """ - try: - # Grab the full image request item from the table - image_request_item = self.job_table.get_image_request(region_request.image_id) + region_request_item = RegionRequestItem.from_region_request(region_request) + self.region_request_table.start_region_request(region_request_item) - logger.debug("Last region of image request was completed, aggregating features for image!") + logger.debug( + f"Adding region request: image id: {region_request_item.image_id} - " + f"region id: {region_request_item.region_id}" + ) - roi = None - if image_request_item.roi_wkt: - logger.debug(f"Using ROI from request to set processing boundary: {image_request_item.roi_wkt}") - roi = shapely.wkt.loads(image_request_item.roi_wkt) - processing_bounds = calculate_processing_bounds(raster_dataset, roi, sensor_model) - logger.debug(f"Processing boundary from {roi} is {processing_bounds}") + self.region_request_queue.send_request(region_request.__dict__) - # Set up our feature table to work with the region quest - feature_table = FeatureTable(self.config.feature_table, region_request.tile_size, region_request.tile_overlap) - # Aggregate all the features from our job - features = feature_table.aggregate_features(image_request_item) - features = self.select_features(image_request_item, features, processing_bounds) - features = self.add_properties_to_features(image_request_item, features) + logger.debug(f"Processing first region {first_region}: {first_region}") - # Sink the features into the right outputs - is_write_succeeded = self.sink_features(image_request_item, features) - if not is_write_succeeded: - raise AggregateOutputFeaturesException( - "Failed to write features to S3 or Kinesis! Please check the " "log..." - ) + first_region_request = RegionRequest( + image_request.get_shared_values(), + region_bounds=first_region, + region_id=f"{first_region[0]}{first_region[1]}-{image_request.job_id}", + image_extension=image_extension, + ) - # Put our end time on our image_request_item - completed_image_request_item = self.job_table.end_image_request(image_request_item.image_id) + first_region_request_item = RegionRequestItem.from_region_request(first_region_request) - # Ensure we have a valid start time for our record - # TODO: Figure out why we wouldn't have a valid start time?!?! - if completed_image_request_item.processing_duration is not None: - image_request_status = self.image_status_monitor.get_status(completed_image_request_item) - self.image_status_monitor.process_event( - completed_image_request_item, image_request_status, "Completed image processing" - ) - self.generate_image_processing_metrics(completed_image_request_item, image_format) - else: - raise InvalidImageRequestException("ImageRequest has no start time") + image_request_item = self.region_request_handler.process_region_request( + first_region_request, first_region_request_item, raster_dataset, sensor_model + ) - except Exception as err: - raise AggregateFeaturesException("Failed to aggregate features for region!") from err + if self.job_table.is_image_request_complete(image_request_item): + image_format = str(raster_dataset.GetDriver().ShortName).upper() + self.complete_image_request(first_region_request, image_format, raster_dataset, sensor_model) @metric_scope def generate_image_processing_metrics( @@ -680,49 +377,6 @@ def generate_image_processing_metrics( if image_request_item.region_error > 0: metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value)) - def fail_region_request( - self, - region_request_item: RegionRequestItem, - metrics: MetricsLogger = None, - ) -> JobItem: - """ - Fails a region if it failed to process successfully and updates the table accordingly before - raising an exception - - :param region_request_item: RegionRequestItem = the region request to update - :param metrics: MetricsLogger = the metrics logger to use to report metrics. - - :return: None - """ - if isinstance(metrics, MetricsLogger): - metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value)) - try: - region_status = RequestStatus.FAILED - region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status) - self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing") - return self.job_table.complete_region_request(region_request_item.image_id, error=True) - except Exception as status_error: - logger.error("Unable to update region status in job table") - logger.exception(status_error) - raise ProcessRegionException("Failed to process image region!") - - def validate_model_hosting(self, image_request: JobItem): - """ - Validates that the image request is valid. If not, raises an exception. - - :param image_request: JobItem = the image request - - :return: None - """ - if not image_request.model_invoke_mode or image_request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS: - error = f"Application only supports ${VALID_MODEL_HOSTING_OPTIONS} Endpoints" - self.image_status_monitor.process_event( - image_request, - RequestStatus.FAILED, - error, - ) - raise UnsupportedModelException(error) - @metric_scope def select_features( self, @@ -824,16 +478,16 @@ def sink_features(image_request_item: JobItem, features: List[Feature], metrics: # Log them let them know if both written to both outputs (S3 and Kinesis) or one in another # If both couldn't write to either stream because both were down, return False. Otherwise True if tracking_output_sinks["S3"] and not tracking_output_sinks["Kinesis"]: - logging.debug("OSMLModelRunner was able to write the features to S3 but not Kinesis. Continuing...") + logging.debug("LModelRunner was able to write the features to S3 but not Kinesis. Continuing...") return True elif not tracking_output_sinks["S3"] and tracking_output_sinks["Kinesis"]: - logging.debug("OSMLModelRunner was able to write the features to Kinesis but not S3. Continuing...") + logging.debug("ModelRunner was able to write the features to Kinesis but not S3. Continuing...") return True elif tracking_output_sinks["S3"] and tracking_output_sinks["Kinesis"]: - logging.debug("OSMLModelRunner was able to write the features to both S3 and Kinesis. Continuing...") + logging.debug("ModelRunner was able to write the features to both S3 and Kinesis. Continuing...") return True else: - logging.error("OSMLModelRunner was not able to write the features to either S3 or Kinesis. Failing...") + logging.error("ModelRunner was not able to write the features to either S3 or Kinesis. Failing...") return False else: raise InvalidImageRequestException("No output destinations were defined for this image request!") @@ -898,6 +552,23 @@ def get_inference_metadata_property(image_request_item: JobItem, inference_time: } return inference_metadata_property + def validate_model_hosting(self, image_request: JobItem): + """ + Validates that the image request is valid. If not, raises an exception. + + :param image_request: JobItem = the image request + + :return: None + """ + if not image_request.model_invoke_mode or image_request.model_invoke_mode not in VALID_MODEL_HOSTING_OPTIONS: + error = f"Application only supports ${VALID_MODEL_HOSTING_OPTIONS} Endpoints" + self.image_status_monitor.process_event( + image_request, + RequestStatus.FAILED, + error, + ) + raise UnsupportedModelException(error) + @staticmethod def get_extents(ds: gdal.Dataset, sm: SensorModel) -> Dict[str, Any]: """ diff --git a/src/aws/osml/model_runner/inference/http_detector.py b/src/aws/osml/model_runner/inference/http_detector.py index 2bf46250..b922a9ae 100644 --- a/src/aws/osml/model_runner/inference/http_detector.py +++ b/src/aws/osml/model_runner/inference/http_detector.py @@ -16,8 +16,8 @@ from urllib3.util.retry import Retry from aws.osml.model_runner.api import ModelInvokeMode -from aws.osml.model_runner.app_config import MetricLabels from aws.osml.model_runner.common import Timer +from aws.osml.model_runner.config import MetricLabels from .detector import Detector from .endpoint_builder import FeatureEndpointBuilder diff --git a/src/aws/osml/model_runner/inference/sm_detector.py b/src/aws/osml/model_runner/inference/sm_detector.py index 093518b0..ca840123 100644 --- a/src/aws/osml/model_runner/inference/sm_detector.py +++ b/src/aws/osml/model_runner/inference/sm_detector.py @@ -14,8 +14,8 @@ from geojson import FeatureCollection from aws.osml.model_runner.api import ModelInvokeMode -from aws.osml.model_runner.app_config import BotoConfig, MetricLabels from aws.osml.model_runner.common import Timer +from aws.osml.model_runner.config import BotoConfig, MetricLabels from .detector import Detector from .endpoint_builder import FeatureEndpointBuilder diff --git a/src/aws/osml/model_runner/model_runner.py b/src/aws/osml/model_runner/model_runner.py new file mode 100755 index 00000000..7461c191 --- /dev/null +++ b/src/aws/osml/model_runner/model_runner.py @@ -0,0 +1,174 @@ +# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. + +import logging + +from osgeo import gdal + +from aws.osml.gdal import load_gdal_dataset, set_gdal_default_configuration + +from .api import ImageRequest, InvalidImageRequestException, RegionRequest +from .common import EndpointUtils, ThreadingLocalContextFilter +from .config import ServiceConfig +from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable +from .exceptions import RetryableJobException, SelfThrottledRegionException +from .image_request_handler import ImageRequestHandler +from .queue import RequestQueue +from .region_request_handler import RegionRequestHandler +from .status import ImageStatusMonitor, RegionStatusMonitor +from .tile_worker import TilingStrategy, VariableOverlapTilingStrategy + +# Set up logging configuration +logger = logging.getLogger(__name__) + +# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume +# no exceptions so we call this explicitly until the software can be updated to match. +gdal.UseExceptions() + + +class ModelRunner: + """ + Main class for operating the ModelRunner application. It monitors input queues for processing requests, + decomposes the image into a set of smaller regions and tiles, invokes an ML model endpoint with each tile, and + finally aggregates all the results into a single output which can be deposited into the desired output sinks. + """ + + def __init__(self, tiling_strategy: TilingStrategy = VariableOverlapTilingStrategy()) -> None: + """ + Initialize a model runner with the injectable behaviors. + + :param tiling_strategy: class defining how a larger image will be broken into chunks for processing + """ + self.config = ServiceConfig() + self.tiling_strategy = tiling_strategy + self.image_request_queue = RequestQueue(self.config.image_queue, wait_seconds=0) + self.image_requests_iter = iter(self.image_request_queue) + self.job_table = JobTable(self.config.job_table) + self.region_request_table = RegionRequestTable(self.config.region_request_table) + self.endpoint_statistics_table = EndpointStatisticsTable(self.config.endpoint_statistics_table) + self.region_request_queue = RequestQueue(self.config.region_queue, wait_seconds=10) + self.region_requests_iter = iter(self.region_request_queue) + self.image_status_monitor = ImageStatusMonitor(self.config.image_status_topic) + self.region_status_monitor = RegionStatusMonitor(self.config.region_status_topic) + self.endpoint_utils = EndpointUtils() + self.running = False + + # Pass dependencies into RegionRequestHandler + self.region_request_handler = RegionRequestHandler( + region_request_table=self.region_request_table, + job_table=self.job_table, + region_status_monitor=self.region_status_monitor, + endpoint_statistics_table=self.endpoint_statistics_table, + tiling_strategy=self.tiling_strategy, + region_request_queue=self.region_request_queue, + endpoint_utils=self.endpoint_utils, + config=self.config, + ) + + # Pass dependencies into ImageRequestHandler + self.image_request_handler = ImageRequestHandler( + job_table=self.job_table, + image_status_monitor=self.image_status_monitor, + endpoint_statistics_table=self.endpoint_statistics_table, + tiling_strategy=self.tiling_strategy, + region_request_queue=self.region_request_queue, + region_request_table=self.region_request_table, + endpoint_utils=self.endpoint_utils, + config=self.config, + region_request_handler=self.region_request_handler, + ) + + def run(self) -> None: + """ + Starts ModelRunner in a loop that continuously monitors the image work queue and region work queue. + """ + self.monitor_work_queues() + + def stop(self) -> None: + """ + Stops ModelRunner by setting the global run variable to False. + """ + self.running = False + + def monitor_work_queues(self) -> None: + """ + Monitors SQS queues for ImageRequest and RegionRequest. + """ + self.running = True + set_gdal_default_configuration() + + try: + while self.running: + logger.debug("Checking work queue for regions to process ...") + (receipt_handle, region_request_attributes) = next(self.region_requests_iter) + ThreadingLocalContextFilter.set_context(region_request_attributes) + + if region_request_attributes is not None: + try: + region_request = RegionRequest(region_request_attributes) + + if "s3:/" in region_request.image_url: + ImageRequest.validate_image_path(region_request.image_url, region_request.image_read_role) + image_path = region_request.image_url.replace("s3:/", "/vsis3", 1) + else: + image_path = region_request.image_url + + raster_dataset, sensor_model = load_gdal_dataset(image_path) + image_format = str(raster_dataset.GetDriver().ShortName).upper() + + region_request_item = self.region_request_table.get_region_request( + region_request.region_id, region_request.image_id + ) + if region_request_item is None: + region_request_item = RegionRequestItem.from_region_request(region_request) + + image_request_item = self.region_request_handler.process_region_request( + region_request, region_request_item, raster_dataset, sensor_model + ) + + if self.job_table.is_image_request_complete(image_request_item): + self.image_request_handler.complete_image_request( + region_request, image_format, raster_dataset, sensor_model + ) + + self.region_request_queue.finish_request(receipt_handle) + except RetryableJobException: + self.region_request_queue.reset_request(receipt_handle, visibility_timeout=0) + except SelfThrottledRegionException: + self.region_request_queue.reset_request( + receipt_handle, visibility_timeout=int(self.config.throttling_retry_timeout) + ) + except Exception as err: + logger.error(f"There was a problem processing the region request: {err}") + self.region_request_queue.finish_request(receipt_handle) + else: + logger.debug("Checking work queue for images to process ...") + (receipt_handle, image_request_message) = next(self.image_requests_iter) + + if image_request_message is not None: + image_request = None + try: + image_request = ImageRequest.from_external_message(image_request_message) + ThreadingLocalContextFilter.set_context(image_request.__dict__) + + if not image_request.is_valid(): + error = f"Invalid image request: {image_request_message}" + logger.exception(error) + raise InvalidImageRequestException(error) + + self.image_request_handler.process_image_request(image_request) + self.image_request_queue.finish_request(receipt_handle) + except RetryableJobException: + self.image_request_queue.reset_request(receipt_handle, visibility_timeout=0) + except Exception as err: + logger.error(f"There was a problem processing the image request: {err}") + min_image_id = image_request.image_id if image_request and image_request.image_id else "" + min_job_id = image_request.job_id if image_request and image_request.job_id else "" + minimal_job_item = JobItem( + image_id=min_image_id, + job_id=min_job_id, + processing_duration=0, + ) + self.image_request_handler.fail_image_request(minimal_job_item, err) + self.image_request_queue.finish_request(receipt_handle) + finally: + self.running = False diff --git a/src/aws/osml/model_runner/queue/request_queue.py b/src/aws/osml/model_runner/queue/request_queue.py index 5e539db5..1478e5fe 100755 --- a/src/aws/osml/model_runner/queue/request_queue.py +++ b/src/aws/osml/model_runner/queue/request_queue.py @@ -7,7 +7,7 @@ import boto3 from botocore.exceptions import ClientError -from aws.osml.model_runner.app_config import BotoConfig +from aws.osml.model_runner.config import BotoConfig class RequestQueue: diff --git a/src/aws/osml/model_runner/region_request_handler.py b/src/aws/osml/model_runner/region_request_handler.py new file mode 100644 index 00000000..0d67a7d5 --- /dev/null +++ b/src/aws/osml/model_runner/region_request_handler.py @@ -0,0 +1,207 @@ +# Copyright 2023-2024 Amazon.com, Inc. or its affiliates. + +import logging +from typing import Optional + +from aws_embedded_metrics.logger.metrics_logger import MetricsLogger +from aws_embedded_metrics.metric_scope import metric_scope +from aws_embedded_metrics.unit import Unit +from osgeo import gdal + +from aws.osml.photogrammetry import SensorModel + +from .api import RegionRequest +from .common import EndpointUtils, RequestStatus, Timer +from .config import MetricLabels, ServiceConfig +from .database import EndpointStatisticsTable, JobItem, JobTable, RegionRequestItem, RegionRequestTable +from .exceptions import ProcessRegionException, SelfThrottledRegionException +from .queue import RequestQueue +from .status import RegionStatusMonitor +from .tile_worker import TilingStrategy, process_tiles, setup_tile_workers + +# Set up logging configuration +logger = logging.getLogger(__name__) + +# GDAL 4.0 will begin using exceptions as the default; at this point the software is written to assume +# no exceptions so we call this explicitly until the software can be updated to match. +gdal.UseExceptions() + + +class RegionRequestHandler: + """ + Class responsible for handling RegionRequest processing. + """ + + def __init__( + self, + region_request_table: RegionRequestTable, + job_table: JobTable, + region_status_monitor: RegionStatusMonitor, + endpoint_statistics_table: EndpointStatisticsTable, + tiling_strategy: TilingStrategy, + region_request_queue: RequestQueue, + endpoint_utils: EndpointUtils, + config: ServiceConfig, + ) -> None: + """ + Initialize the RegionRequestHandler with the necessary dependencies. + + :param region_request_table: The table that handles region requests. + :param job_table: The job table for image/region processing. + :param region_status_monitor: A monitor to track region request status. + :param endpoint_statistics_table: Table for tracking endpoint statistics. + :param tiling_strategy: The strategy for handling image tiling. + :param region_request_queue: Queue to send region requests. + :param endpoint_utils: Utility class for handling endpoint-related operations. + :param config: Configuration settings for the service. + """ + self.region_request_table = region_request_table + self.job_table = job_table + self.region_status_monitor = region_status_monitor + self.endpoint_statistics_table = endpoint_statistics_table + self.tiling_strategy = tiling_strategy + self.region_request_queue = region_request_queue + self.endpoint_utils = endpoint_utils + self.config = config + + @metric_scope + def process_region_request( + self, + region_request: RegionRequest, + region_request_item: RegionRequestItem, + raster_dataset: gdal.Dataset, + sensor_model: Optional[SensorModel] = None, + metrics: MetricsLogger = None, + ) -> JobItem: + """ + Processes RegionRequest objects that are delegated for processing. Loads the specified region of an image into + memory to be processed by tile-workers. If a raster_dataset is not provided directly it will poll the image + from the region request. + + :param region_request: RegionRequest = the region request + :param region_request_item: RegionRequestItem = the region request to update + :param raster_dataset: gdal.Dataset = the raster dataset containing the region + :param sensor_model: Optional[SensorModel] = the sensor model for this raster dataset + :param metrics: MetricsLogger = the metrics logger to use to report metrics. + + :return: JobItem + """ + if isinstance(metrics, MetricsLogger): + metrics.set_dimensions() + + if not region_request.is_valid(): + logger.error(f"Invalid Region Request! {region_request.__dict__}") + raise ValueError("Invalid Region Request") + + if isinstance(metrics, MetricsLogger): + image_format = str(raster_dataset.GetDriver().ShortName).upper() + metrics.put_dimensions( + { + MetricLabels.OPERATION_DIMENSION: MetricLabels.REGION_PROCESSING_OPERATION, + MetricLabels.MODEL_NAME_DIMENSION: region_request.model_name, + MetricLabels.INPUT_FORMAT_DIMENSION: image_format, + } + ) + + if self.config.self_throttling: + max_regions = self.endpoint_utils.calculate_max_regions( + region_request.model_name, region_request.model_invocation_role + ) + # Add entry to the endpoint statistics table + self.endpoint_statistics_table.upsert_endpoint(region_request.model_name, max_regions) + in_progress = self.endpoint_statistics_table.current_in_progress_regions(region_request.model_name) + + if in_progress >= max_regions: + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.THROTTLES, 1, str(Unit.COUNT.value)) + logger.warning(f"Throttling region request. (Max: {max_regions} In-progress: {in_progress}") + raise SelfThrottledRegionException + + # Increment the endpoint region counter + self.endpoint_statistics_table.increment_region_count(region_request.model_name) + + try: + with Timer( + task_str=f"Processing region {region_request.image_url} {region_request.region_bounds}", + metric_name=MetricLabels.DURATION, + logger=logger, + metrics_logger=metrics, + ): + self.region_request_table.start_region_request(region_request_item) + logging.debug( + f"Starting region request: region id: {region_request_item.region_id}" + ) + + # Set up our threaded tile worker pool + tile_queue, tile_workers = setup_tile_workers(region_request, sensor_model, self.config.elevation_model) + + # Process all our tiles + total_tile_count, failed_tile_count = process_tiles( + self.tiling_strategy, + region_request_item, + tile_queue, + tile_workers, + raster_dataset, + sensor_model, + ) + + # Update table w/ total tile counts + region_request_item.total_tiles = total_tile_count + region_request_item.succeeded_tile_count = total_tile_count - failed_tile_count + region_request_item.failed_tile_count = failed_tile_count + region_request_item = self.region_request_table.update_region_request(region_request_item) + + # Update the image request to complete this region + image_request_item = self.job_table.complete_region_request(region_request.image_id, bool(failed_tile_count)) + + # Update region request table if that region succeeded + region_status = self.region_status_monitor.get_status(region_request_item) + region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status) + + self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing") + + # Write CloudWatch Metrics to the Logs + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.INVOCATIONS, 1, str(Unit.COUNT.value)) + + # Return the updated item + return image_request_item + + except Exception as err: + failed_msg = f"Failed to process image region: {err}" + logger.error(failed_msg) + # Update the table to record the failure + region_request_item.message = failed_msg + return self.fail_region_request(region_request_item) + + finally: + # Decrement the endpoint region counter + if self.config.self_throttling: + self.endpoint_statistics_table.decrement_region_count(region_request.model_name) + + @metric_scope + def fail_region_request( + self, + region_request_item: RegionRequestItem, + metrics: MetricsLogger = None, + ) -> JobItem: + """ + Fails a region if it failed to process successfully and updates the table accordingly before + raising an exception + + :param region_request_item: RegionRequestItem = the region request to update + :param metrics: MetricsLogger = the metrics logger to use to report metrics. + + :return: None + """ + if isinstance(metrics, MetricsLogger): + metrics.put_metric(MetricLabels.ERRORS, 1, str(Unit.COUNT.value)) + try: + region_status = RequestStatus.FAILED + region_request_item = self.region_request_table.complete_region_request(region_request_item, region_status) + self.region_status_monitor.process_event(region_request_item, region_status, "Completed region processing") + return self.job_table.complete_region_request(region_request_item.image_id, error=True) + except Exception as status_error: + logger.error("Unable to update region status in job table") + logger.exception(status_error) + raise ProcessRegionException("Failed to process image region!") diff --git a/src/aws/osml/model_runner/sink/kinesis_sink.py b/src/aws/osml/model_runner/sink/kinesis_sink.py index 5f4e3732..f6739490 100755 --- a/src/aws/osml/model_runner/sink/kinesis_sink.py +++ b/src/aws/osml/model_runner/sink/kinesis_sink.py @@ -9,8 +9,8 @@ from geojson import Feature, FeatureCollection from aws.osml.model_runner.api import SinkMode, SinkType -from aws.osml.model_runner.app_config import BotoConfig, ServiceConfig from aws.osml.model_runner.common import get_credentials_for_assumed_role +from aws.osml.model_runner.config import BotoConfig, ServiceConfig from .exceptions import InvalidKinesisStreamException from .sink import Sink diff --git a/src/aws/osml/model_runner/sink/s3_sink.py b/src/aws/osml/model_runner/sink/s3_sink.py index b7e29316..4fd273af 100755 --- a/src/aws/osml/model_runner/sink/s3_sink.py +++ b/src/aws/osml/model_runner/sink/s3_sink.py @@ -12,8 +12,8 @@ from geojson import Feature, FeatureCollection from aws.osml.model_runner.api import SinkMode, SinkType -from aws.osml.model_runner.app_config import BotoConfig from aws.osml.model_runner.common import get_credentials_for_assumed_role +from aws.osml.model_runner.config import BotoConfig from .sink import Sink diff --git a/src/aws/osml/model_runner/status/sns_helper.py b/src/aws/osml/model_runner/status/sns_helper.py index 957849f1..1f155f59 100755 --- a/src/aws/osml/model_runner/status/sns_helper.py +++ b/src/aws/osml/model_runner/status/sns_helper.py @@ -5,7 +5,7 @@ import boto3 -from aws.osml.model_runner.app_config import BotoConfig +from aws.osml.model_runner.config import BotoConfig from .exceptions import SNSPublishException diff --git a/src/aws/osml/model_runner/tile_worker/tile_worker.py b/src/aws/osml/model_runner/tile_worker/tile_worker.py index 7b6badfa..2fbc8fec 100755 --- a/src/aws/osml/model_runner/tile_worker/tile_worker.py +++ b/src/aws/osml/model_runner/tile_worker/tile_worker.py @@ -14,8 +14,8 @@ from shapely.affinity import translate from aws.osml.features import Geolocator, ImagedFeaturePropertyAccessor -from aws.osml.model_runner.app_config import MetricLabels from aws.osml.model_runner.common import ThreadingLocalContextFilter, TileState, Timer +from aws.osml.model_runner.config import MetricLabels from aws.osml.model_runner.database import FeatureTable, RegionRequestTable from aws.osml.model_runner.inference import Detector diff --git a/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py b/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py index 750441d0..ff12aebd 100755 --- a/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py +++ b/src/aws/osml/model_runner/tile_worker/tile_worker_utils.py @@ -19,8 +19,8 @@ from aws.osml.photogrammetry import ElevationModel, SensorModel from ..api import RegionRequest -from ..app_config import MetricLabels, ServiceConfig from ..common import Timer, get_credentials_for_assumed_role +from ..config import MetricLabels, ServiceConfig from ..database import FeatureTable from ..inference.endpoint_factory import FeatureDetectorFactory from .exceptions import ProcessTilesException, SetupTileWorkersException diff --git a/test/test_app.py b/test/test_app.py index 6f7bfd0d..c01dcfa5 100755 --- a/test/test_app.py +++ b/test/test_app.py @@ -452,11 +452,11 @@ def test_create_elevation_model(self) -> None: The import and reload statements are necessary to force the ServiceConfig to update with the patched environment variables. """ - import aws.osml.model_runner.app_config + import aws.osml.model_runner.config - reload(aws.osml.model_runner.app_config) + reload(aws.osml.model_runner.config) from aws.osml.gdal.gdal_dem_tile_factory import GDALDigitalElevationModelTileFactory - from aws.osml.model_runner.app_config import ServiceConfig + from aws.osml.model_runner.config import ServiceConfig from aws.osml.photogrammetry.digital_elevation_model import DigitalElevationModel from aws.osml.photogrammetry.srtm_dem_tile_set import SRTMTileSet