From 22193c4f8d7f4c83d191e2e7d0e527a4343840e6 Mon Sep 17 00:00:00 2001 From: Phil Snyder Date: Fri, 17 May 2024 15:46:38 -0700 Subject: [PATCH] Add `filter_object_info` function to dispatch Lambda --- src/lambda_function/dispatch/app.py | 53 +++++++++++++++++++++++++++-- tests/test_lambda_dispatch.py | 53 +++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/src/lambda_function/dispatch/app.py b/src/lambda_function/dispatch/app.py index c25821fd..35f14930 100644 --- a/src/lambda_function/dispatch/app.py +++ b/src/lambda_function/dispatch/app.py @@ -9,6 +9,7 @@ import logging import os import zipfile +from typing import Optional # use | for type hints in 3.10+ from urllib import parse import boto3 @@ -16,6 +17,49 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) +def filter_object_info(object_info: dict) -> Optional[dict]: + """ + Filter out objects that should not be processed. + + Returns None for: + + - Records containing owner.txt + - Records that don't contain a specific object key like / + - Records that are missing the `Key` field. + - Records that are missing the `Bucket` field. + + Args: + object_info (dict): Object information from source S3 bucket + as formatted by `get_object_info`. + + Returns: + dict: `object_info` if it passes the filter criteria (i.e., acts as + identity function) otherwise returns None. + """ + if not object_info["Key"]: + logger.info( + "This object_info record doesn't contain a source key " + f"and can't be processed.\nMessage: {object_info}", + ) + return None + elif not object_info["Bucket"]: + logger.info( + "This object_info record doesn't contain a source bucket " + f"and can't be processed.\nMessage: {object_info}", + ) + return None + elif "owner.txt" in object_info["Key"]: + logger.info( + f"This object_info record is an owner.txt and can't be processed.\nMessage: {object_info}" + ) + return None + elif object_info["Key"].endswith("/"): + logger.info( + f"This object_info record is a directory and can't be processed.\nMessage: {object_info}" + ) + return None + return object_info + def get_object_info(s3_event: dict) -> dict: """ Derive object info from an S3 event. @@ -124,8 +168,13 @@ def main( sns_notification = json.loads(sqs_record["body"]) sns_message = json.loads(sns_notification["Message"]) logger.info(f"Received SNS message: {sns_message}") - for s3_event in sns_message["Records"]: - object_info = get_object_info(s3_event) + all_object_info_list = map(get_object_info, sns_message["Records"]) + valid_object_info_list = [ + object_info + for object_info in all_object_info_list + if filter_object_info(object_info) is not None + ] + for object_info in valid_object_info_list: s3_client.download_file(Filename=temp_zip_path, **object_info) logger.info(f"Getting archive contents for {object_info}") archive_contents = get_archive_contents( diff --git a/tests/test_lambda_dispatch.py b/tests/test_lambda_dispatch.py index 50b8655f..5de358a5 100644 --- a/tests/test_lambda_dispatch.py +++ b/tests/test_lambda_dispatch.py @@ -135,6 +135,59 @@ def test_get_object_info_unicode_characters_in_key(s3_event): assert object_info["Key"] == \ "main/2023-09-26T00:06:39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" +@pytest.mark.parametrize( + "object_info,expected", + [ + ( + { + "Bucket": "recover-dev-input-data", + "Key": "main/2023-01-12T22--02--17Z_77fefff8-b0e2-4c1b-b0c5-405554c92368", + }, + { + "Bucket": "recover-dev-input-data", + "Key": "main/2023-01-12T22--02--17Z_77fefff8-b0e2-4c1b-b0c5-405554c92368", + }, + ), + ( + { + "Bucket": "recover-dev-input-data", + "Key": "main/v1/owner.txt", + }, + None, + ), + ( + { + "Bucket": "recover-dev-input-data", + "Key": "main/adults_v2/", + }, + None, + ), + ( + { + "Bucket": "recover-dev-input-data", + "Key": None, + }, + None, + ), + ( + { + "Bucket": None, + "Key": "main/2023-01-12T22--02--17Z_77fefff8-b0e2-4c1b-b0c5-405554c92368", + }, + None, + ), + ], + ids=[ + "correct_msg_format", + "owner_txt", + "directory", + "missing_key", + "missing_bucket", + ], +) +def test_that_filter_object_info_returns_expected_result(object_info, expected): + assert app.filter_object_info(object_info) == expected + def test_get_archive_contents(archive_path, archive_json_paths): dummy_bucket = "dummy_bucket" dummy_key = "dummy_key"