diff --git a/tests/serialize/runstate/dynamodb_state_store_test.py b/tests/serialize/runstate/dynamodb_state_store_test.py
index 687b331bd..e7ad7a1ec 100644
--- a/tests/serialize/runstate/dynamodb_state_store_test.py
+++ b/tests/serialize/runstate/dynamodb_state_store_test.py
@@ -9,6 +9,7 @@
 
 from testifycompat import assert_equal
 from tron.serialize.runstate.dynamodb_state_store import DynamoDBStateStore
+from tron.serialize.runstate.dynamodb_state_store import MAX_UNPROCESSED_KEYS_RETRIES
 
 
 def mock_transact_write_items(self):
@@ -294,58 +295,80 @@ def test_delete_item_with_json_partitions(self, store, small_object, large_objec
         vals = store.restore([key])
         assert key not in vals
 
-    def test_retry_saving(self, store, small_object, large_object):
-        with mock.patch(
-            "moto.dynamodb2.responses.DynamoHandler.transact_write_items",
-            side_effect=KeyError("foo"),
-        ) as mock_failed_write:
-            keys = [store.build_key("job_state", i) for i in range(1)]
-            value = small_object
-            pairs = zip(keys, (value for i in range(len(keys))))
-            try:
-                store.save(pairs)
-            except Exception:
-                assert_equal(mock_failed_write.call_count, 3)
-
-    def test_retry_reading(self, store, small_object, large_object):
+    @pytest.mark.parametrize(
+        "test_object, side_effects, expected_save_errors, expected_queue_length",
+        [
+            # All attempts fail
+            ("small_object", [KeyError("foo")] * 3, 3, 1),
+            ("large_object", [KeyError("foo")] * 3, 3, 1),
+            # Failure followed by success
+            ("small_object", [KeyError("foo"), {}], 0, 0),
+            ("large_object", [KeyError("foo"), {}], 0, 0),
+        ],
+    )
+    def test_retry_saving(
+        self, test_object, side_effects, expected_save_errors, expected_queue_length, store, small_object, large_object
+    ):
+        object_mapping = {
+            "small_object": small_object,
+            "large_object": large_object,
+        }
+        value = object_mapping[test_object]
+
+        with mock.patch.object(
+            store.client,
+            "transact_write_items",
+            side_effect=side_effects,
+        ) as mock_transact_write:
+            keys = [store.build_key("job_state", 0)]
+            pairs = zip(keys, [value])
+            store.save(pairs)
+
+            for _ in side_effects:
+                store._consume_save_queue()
+
+            assert mock_transact_write.call_count == len(side_effects)
+            assert store.save_errors == expected_save_errors
+            assert len(store.save_queue) == expected_queue_length
+
+    @pytest.mark.parametrize(
+        "attempt, expected_delay",
+        [
+            (1, 1),
+            (2, 2),
+            (3, 4),
+            (4, 8),
+            (5, 10),
+            (6, 10),
+            (7, 10),
+        ],
+    )
+    def test_calculate_backoff_delay(self, store, attempt, expected_delay):
+        delay = store._calculate_backoff_delay(attempt)
+        assert_equal(delay, expected_delay)
+
+    def test_retry_reading(self, store):
         unprocessed_value = {
-            "Responses": {
-                store.name: [
-                    {
-                        "index": {"N": "0"},
-                        "key": {"S": "job_state 0"},
-                    },
-                ],
-            },
+            "Responses": {},
             "UnprocessedKeys": {
                 store.name: {
+                    "Keys": [{"key": {"S": store.build_key("job_state", 0)}, "index": {"N": "0"}}],
                     "ConsistentRead": True,
-                    "Keys": [
-                        {
-                            "index": {"N": "0"},
-                            "key": {"S": "job_state 0"},
-                        }
-                    ],
-                },
+                }
             },
-            "ResponseMetadata": {},
         }
-        keys = [store.build_key("job_state", i) for i in range(1)]
-        value = small_object
-        pairs = zip(keys, (value for i in range(len(keys))))
-        store.save(pairs)
+
+        keys = [store.build_key("job_state", 0)]
+
         with mock.patch.object(
             store.client,
             "batch_get_item",
             return_value=unprocessed_value,
-        ) as mock_failed_read:
-            try:
-                with mock.patch("tron.config.static_config.load_yaml_file", autospec=True), mock.patch(
-                    "tron.config.static_config.build_configuration_watcher", autospec=True
-                ):
-                    store.restore(keys)
-            except Exception:
-                assert_equal(mock_failed_read.call_count, 11)
+        ) as mock_batch_get_item, mock.patch("time.sleep") as mock_sleep, pytest.raises(Exception) as exec_info:
+            store.restore(keys)
+        assert "failed to retrieve items with keys" in str(exec_info.value)
+        assert mock_batch_get_item.call_count == MAX_UNPROCESSED_KEYS_RETRIES
+        assert mock_sleep.call_count == MAX_UNPROCESSED_KEYS_RETRIES - 1  # no backoff after the final attempt
 
     def test_restore_exception_propagation(self, store, small_object):
         # This test is to ensure that restore propagates exceptions upwards: see DAR-2328
diff --git a/tron/serialize/runstate/dynamodb_state_store.py b/tron/serialize/runstate/dynamodb_state_store.py
index 35ea52da3..ffaee7e47 100644
--- a/tron/serialize/runstate/dynamodb_state_store.py
+++ b/tron/serialize/runstate/dynamodb_state_store.py
@@ -20,6 +20,8 @@
 from typing import TypeVar
 
 import boto3  # type: ignore
+import botocore  # type: ignore
+from botocore.config import Config  # type: ignore
 
 import tron.prom_metrics as prom_metrics
 from tron.core.job import Job
@@ -35,7 +37,10 @@
 # to contain other attributes like object name and number of partitions.
 OBJECT_SIZE = 200_000  # TODO: TRON-2240 - consider swapping back to 400_000 now that we've removed pickles
 MAX_SAVE_QUEUE = 500
-MAX_ATTEMPTS = 10
+# This is distinct from the number of retries in the retry_config as this is used for handling unprocessed
+# keys outside the bounds of something like retrying on a ThrottlingException. We need this limit to avoid
+# infinite loops in the case where a key is truly unprocessable.
+MAX_UNPROCESSED_KEYS_RETRIES = 10
 MAX_TRANSACT_WRITE_ITEMS = 100
 log = logging.getLogger(__name__)
 T = TypeVar("T")
@@ -43,8 +48,22 @@
 
 class DynamoDBStateStore:
     def __init__(self, name, dynamodb_region, stopping=False) -> None:
-        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region)
-        self.client = boto3.client("dynamodb", region_name=dynamodb_region)
+        # Standard mode includes an exponential backoff by a base factor of 2 for a
+        # maximum backoff time of 20 seconds (min(b*r^i, MAX_BACKOFF) where b is a
+        # random number between 0 and 1 and r is the base factor of 2). This might
+        # look like:
+        #
+        # seconds_to_sleep = min(1 × 2^1, 20) = min(2, 20) = 2 seconds
+        #
+        # By our 5th retry (2^5 is 32) we will be sleeping *up to* 20 seconds, depending
+        # on the random jitter.
+        #
+        # It handles transient errors like RequestTimeout and ConnectionError, as well
+        # as Service-side errors like Throttling, SlowDown, and LimitExceeded.
+        retry_config = Config(retries={"max_attempts": 5, "mode": "standard"})
+
+        self.dynamodb = boto3.resource("dynamodb", region_name=dynamodb_region, config=retry_config)
+        self.client = boto3.client("dynamodb", region_name=dynamodb_region, config=retry_config)
         self.name = name
         self.dynamodb_region = dynamodb_region
         self.table = self.dynamodb.Table(name)
@@ -63,11 +82,11 @@ def build_key(self, type, iden) -> str:
 
     def restore(self, keys, read_json: bool = False) -> dict:
         """
-        Fetch all under the same parition key(s).
+        Fetch all under the same partition key(s).
         ret:
         """
         # format of the keys always passed here is
-        # job_state job_name  --> high level info about the job: enabled, run_nums
+        # job_state job_name --> high level info about the job: enabled, run_nums
         # job_run_state job_run_name --> high level info about the job run
         first_items = self._get_first_partitions(keys)
         remaining_items = self._get_remaining_partitions(first_items, read_json)
@@ -83,12 +102,20 @@ def chunk_keys(self, keys: Sequence[T]) -> List[Sequence[T]]:
             cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
         return cand_keys_chunks
 
+    def _calculate_backoff_delay(self, attempt: int) -> int:
+        """Return the backoff in seconds for a 1-based attempt: 1s doubling per attempt, capped at 10s."""
+        base_delay_seconds = 1
+        max_delay_seconds = 10
+        delay: int = min(base_delay_seconds * (2 ** (attempt - 1)), max_delay_seconds)
+        return delay
+
     def _get_items(self, table_keys: list) -> object:
         items = []
         # let's avoid potentially mutating our input :)
         cand_keys_list = copy.copy(table_keys)
-        attempts_to_retrieve_keys = 0
-        while len(cand_keys_list) != 0:
+        attempts = 0
+
+        while len(cand_keys_list) != 0 and attempts < MAX_UNPROCESSED_KEYS_RETRIES:
             with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                 responses = [
                     executor.submit(
@@ -106,20 +133,36 @@ def _get_items(self, table_keys: list) -> object:
                 cand_keys_list = []
             for resp in concurrent.futures.as_completed(responses):
                 try:
-                    items.extend(resp.result()["Responses"][self.name])
-                    # add any potential unprocessed keys to the thread pool
-                    if resp.result()["UnprocessedKeys"].get(self.name) and attempts_to_retrieve_keys < MAX_ATTEMPTS:
-                        cand_keys_list.extend(resp.result()["UnprocessedKeys"][self.name]["Keys"])
-                    elif attempts_to_retrieve_keys >= MAX_ATTEMPTS:
-                        failed_keys = resp.result()["UnprocessedKeys"][self.name]["Keys"]
-                        error = Exception(
-                            f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{failed_keys}\n from dynamodb\n{resp.result()}"
-                        )
-                        raise error
-                except Exception as e:
+                    result = resp.result()
+                    items.extend(result.get("Responses", {}).get(self.name, []))
+
+                    # If DynamoDB returns unprocessed keys, we need to collect them and retry
+                    unprocessed_keys = result.get("UnprocessedKeys", {}).get(self.name, {}).get("Keys", [])
+                    if unprocessed_keys:
+                        cand_keys_list.extend(unprocessed_keys)
+                except botocore.exceptions.ClientError as e:
+                    log.exception(f"ClientError during batch_get_item: {e.response}")
+                    raise
+                except Exception:
                     log.exception("Encountered issues retrieving data from DynamoDB")
-                    raise e
-            attempts_to_retrieve_keys += 1
+                    raise
+            if cand_keys_list:
+                attempts += 1
+                # Only back off when another attempt will actually follow; sleeping after the
+                # final attempt would just delay the failure raised below.
+                if attempts < MAX_UNPROCESSED_KEYS_RETRIES:
+                    delay = self._calculate_backoff_delay(attempts)
+                    log.warning(
+                        f"Attempt {attempts}/{MAX_UNPROCESSED_KEYS_RETRIES} - "
+                        f"Retrying {len(cand_keys_list)} unprocessed keys after {delay}s delay."
+                    )
+                    time.sleep(delay)
+        if cand_keys_list:
+            error = Exception(
+                f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys_list}\n from dynamodb after {MAX_UNPROCESSED_KEYS_RETRIES} retries."
+            )
+            log.error(error)
+            raise error
         return items
 
     def _get_first_partitions(self, keys: list):
@@ -337,25 +380,21 @@ def __setitem__(self, key: str, value: Tuple[bytes, str]) -> None:
                     "N": str(num_json_val_partitions),
                 }
 
-            count = 0
             items.append(item)
 
-            while len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
+            # We want to write the items when we've either reached the max number of items
+            # for a transaction, or when we're done processing all partitions
+            if len(items) == MAX_TRANSACT_WRITE_ITEMS or index == max_partitions - 1:
                 try:
                     self.client.transact_write_items(TransactItems=items)
                     items = []
-                    break  # exit the while loop on successful writing
                 except Exception as e:
-                    count += 1
-                    if count > 3:
-                        timer(
-                            name="tron.dynamodb.setitem",
-                            delta=time.time() - start,
-                        )
-                        log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
-                        raise e
-                    else:
-                        log.warning(f"Got error while saving {key}, trying again: {repr(e)}")
+                    timer(
+                        name="tron.dynamodb.setitem",
+                        delta=time.time() - start,
+                    )
+                    log.error(f"Failed to save partition for key: {key}, error: {repr(e)}")
+                    raise
         timer(
             name="tron.dynamodb.setitem",
             delta=time.time() - start,