Parallelizing execution of restoring in Tron
EmanElsaban committed Apr 17, 2024
1 parent abd245f commit a5c94e1
Showing 5 changed files with 65 additions and 34 deletions.
4 changes: 2 additions & 2 deletions tron/core/job_collection.py
@@ -69,8 +69,8 @@ def update(self, new_job_scheduler):

     def restore_state(self, job_state_data, config_action_runner):
         """
-        Loops through the jobs and their runs in order to restore
-        state for each run. As we restore state, we will also schedule the next
+        Loops through the jobs and their runs in order to load their
+        state for each run. As we load the state, we will also schedule the next
         runs for each job
         """
         for name, state in job_state_data.items():
2 changes: 1 addition & 1 deletion tron/core/job_scheduler.py
@@ -25,7 +25,7 @@ def __init__(self, job):
         self.watch(job)
 
     def restore_state(self, job_state_data, config_action_runner):
-        """Restore the job state and schedule any JobRuns."""
+        """Load the job state and schedule any JobRuns."""
         job_runs = self.job.get_job_runs_from_state(job_state_data)
         for run in job_runs:
             self.job.watch(run)
13 changes: 12 additions & 1 deletion tron/mcp.py
@@ -72,11 +72,14 @@ def initial_setup(self):
         # The job schedule factories will be created in the function below
         self._load_config()
         # Jobs will also get scheduled (internally) once the state for action runs are restored in restore_state
+        start_restore_state = time.time()
         self.restore_state(
             actioncommand.create_action_runner_factory_from_config(
                 self.config.load().get_master().action_runner,
             ),
         )
+        duration_restore_state = time.time() - start_restore_state
+        log.info(f"Execution time for restore_state function: {duration_restore_state}")
         # Any job with existing state would have been scheduled already. Jobs
         # without any state will be scheduled here.
         self.jobs.run_queue_schedule()
@@ -162,16 +165,24 @@ def get_config_manager(self):
         return self.config
 
     def restore_state(self, action_runner):
-        """Use the state manager to retrieve to persisted state and apply it
+        """Use the state manager to retrieve the persisted state from dynamodb and apply it
         to the configured Jobs.
         """
         log.info("Restoring from DynamoDB")
+        start_time_restore_dynamodb = time.time()
+        # restores the state of the jobs and their runs from DynamoDB
         states = self.state_watcher.restore(self.jobs.get_names())
+        duration_restore = time.time() - start_time_restore_dynamodb
+        log.info(f"Time takes to state_watcher restore state directly from dynamodb: {duration_restore}")
         MesosClusterRepository.restore_state(states.get("mesos_state", {}))
         log.info(
             f"Tron will start restoring state for the jobs and will start scheduling them! Time elapsed since Tron started {time.time() - self.boot_time}"
         )
+        # loads the runs' state and schedule the next run for each job
+        start_time_load_state = time.time()
         self.jobs.restore_state(states.get("job_state", {}), action_runner)
+        duration_load = time.time() - start_time_load_state
+        log.info(f"Time taken in jobs.restore_state: {duration_load}")
         log.info(
             f"Tron completed restoring state for the jobs. Time elapsed since Tron started {time.time() - self.boot_time}"
         )
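The timing added to mcp.py uses the same stopwatch pattern three times: read time.time() before a call, subtract after it, and log.info the difference. A minimal sketch of how that pattern could be factored into a reusable context manager; log_duration is a hypothetical helper, not something this commit adds:

    import logging
    import time
    from contextlib import contextmanager

    log = logging.getLogger(__name__)

    @contextmanager
    def log_duration(label: str):
        """Log how many seconds the wrapped block took."""
        start = time.time()
        try:
            yield
        finally:
            log.info(f"{label} took {time.time() - start:.2f}s")

    # e.g. each timer above would collapse to:
    # with log_duration("restore_state"):
    #     self.restore_state(...)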
51 changes: 28 additions & 23 deletions tron/serialize/runstate/dynamodb_state_store.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import logging
 import math
 import os
@@ -47,38 +48,42 @@ def restore(self, keys) -> dict:
         vals = self._merge_items(first_items, remaining_items)
         return vals
 
-    def _get_items(self, keys: list) -> object:
-        items = []
-        for i in range(0, len(keys), 100):
-            count = 0
-            cand_keys = keys[i : min(len(keys), i + 100)]
-            while True:
-                resp = self.client.batch_get_item(
-                    RequestItems={
-                        self.name: {
-                            "Keys": cand_keys,
-                            "ConsistentRead": True,
-                        },
-                    },
-                )
-                items.extend(resp["Responses"][self.name])
-                if resp["UnprocessedKeys"].get(self.name) and count < 10:
-                    cand_keys = resp["UnprocessedKeys"][self.name]["Keys"]
-                    count += 1
-                elif count >= 10:
-                    error = Exception(
-                        f"tron_dynamodb_restore_failure: failed to retrieve items with keys \n{cand_keys}\n from dynamodb\n{resp}"
-                    )
-                    raise error
-                else:
-                    break
+    def chunk_keys(self, keys: list) -> list:
+        """Generates the cand keys list to be used to read from DynamoDB"""
+        # have a for loop here for all the key chunks we want to go over
+        cand_keys_chunks = []
+        for i in range(0, len(keys), 100):
+            # chunks of 100 keys will be in this list
+            cand_keys_chunks.append(keys[i : min(len(keys), i + 100)])
+        return cand_keys_chunks
+
+    def _get_items(self, table_keys: list) -> object:
+        items = []
+        # precompute the cand_keys and then all we gotta do is submit stuff to the thread pool using the precomputed keys
+        cand_keys_list = self.chunk_keys(table_keys)
+        count = 0
+        while count < len(cand_keys_list):
+            with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+                responses = [
+                    executor.submit(
+                        self.client.batch_get_item, RequestItems={self.name: {"Keys": key, "ConsistentRead": True}}
+                    )
+                    for key in cand_keys_list
+                ]
+                for resp in concurrent.futures.as_completed(responses):
+                    items.extend(resp.result()["Responses"][self.name])
+                    # add any potential unprocessed keys to the thread pool
+                    if resp.result()["UnprocessedKeys"].get(self.name):
+                        cand_keys_list.append(resp.result()["UnprocessedKeys"][self.name]["Keys"])
+            count += 1
         return items
 
     def _get_first_partitions(self, keys: list):
         new_keys = [{"key": {"S": key}, "index": {"N": "0"}} for key in keys]
         return self._get_items(new_keys)
 
     def _get_remaining_partitions(self, items: list):
         """Get items in the remaining partitions: N = 1 and beyond"""
         keys_for_remaining_items = []
         for item in items:
             remaining_items = [
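The new _get_items fans BatchGetItem calls out to a thread pool: chunk_keys splits the key list into chunks of 100 (DynamoDB's per-request limit for BatchGetItem), each chunk becomes one future, and any UnprocessedKeys that come back are appended to the work list so a later pass retries them. A self-contained sketch of the same fan-out pattern against a plain boto3 client; parallel_batch_get is a hypothetical stand-in for the method above, and unlike the committed code it keeps the bounded-retry guard from the old serial loop:

    import concurrent.futures
    from typing import Dict, List

    import boto3

    def parallel_batch_get(table: str, keys: List[Dict], max_workers: int = 5) -> List[Dict]:
        client = boto3.client("dynamodb")
        # BatchGetItem accepts at most 100 keys per request
        pending = [keys[i : i + 100] for i in range(0, len(keys), 100)]
        items: List[Dict] = []
        attempts = 0
        while pending and attempts < 10:
            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [
                    executor.submit(
                        client.batch_get_item,
                        RequestItems={table: {"Keys": chunk, "ConsistentRead": True}},
                    )
                    for chunk in pending
                ]
                pending = []
                for future in concurrent.futures.as_completed(futures):
                    resp = future.result()
                    items.extend(resp.get("Responses", {}).get(table, []))
                    unprocessed = resp.get("UnprocessedKeys", {}).get(table)
                    if unprocessed:
                        # throttled keys get retried on the next pass
                        pending.append(unprocessed["Keys"])
            attempts += 1
        if pending:
            raise Exception(f"failed to retrieve keys after {attempts} attempts: {pending}")
        return items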
29 changes: 22 additions & 7 deletions tron/serialize/runstate/statemanager.py
@@ -1,3 +1,4 @@
+import concurrent.futures
 import itertools
 import logging
 import time
@@ -14,6 +15,8 @@
 from tron.serialize.runstate.yamlstore import YamlStateStore
 from tron.utils import observer
 
+# import threading
+
 log = logging.getLogger(__name__)

@@ -135,6 +138,7 @@ def __init__(self, persistence_impl, buffer):
         self.enabled = True
         self._buffer = buffer
         self._impl = persistence_impl
+        # self._lock = threading.Lock()
         self.metadata_key = self._impl.build_key(
             runstate.MCP_STATE,
             StateMetadata.name,
@@ -146,11 +150,21 @@ def restore(self, job_names, skip_validation=False):
         if not skip_validation:
             self._restore_metadata()
 
+        # First, restore the jobs themselves
         jobs = self._restore_dicts(runstate.JOB_STATE, job_names)
-        for job_name, job_state in jobs.items():
-            job_state["runs"] = self._restore_runs_for_job(job_name, job_state)
+        # jobs should be a dictionary that contains job name and number of runs
+        # {'MASTER.k8s': {'run_nums':[0], 'enabled': True}, 'MASTER.cits_test_frequent_1': {'run_nums': [1,0], 'enabled': True}}
+
+        # second, restore the runs for each of the jobs restored above
+        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+            # start the threads and mark each future with it's job name
+            # this is useful so that we can index the job name later to add the runs to the jobs dictionary
+            results = {
+                executor.submit(self._restore_runs_for_job, job_name, job_state): job_name
+                for job_name, job_state in jobs.items()
+            }
+            for result in concurrent.futures.as_completed(results):
+                jobs[results[result]]["runs"] = result.result()
         frameworks = self._restore_dicts(runstate.MESOS_STATE, ["frameworks"])
 
         state = {
@@ -160,15 +174,16 @@
         return state
 
     def _restore_runs_for_job(self, job_name, job_state):
         """Restore the state for the runs of each job"""
         run_nums = job_state["run_nums"]
         runs = []
+        keys_ids_list = []
+        # with self._lock:
         for run_num in run_nums:
             key = jobrun.get_job_run_id(job_name, run_num)
-            run_state = list(self._restore_dicts(runstate.JOB_RUN_STATE, [key]).values())
-            if not run_state:
-                log.error(f"Failed to restore {key}, no state found for it")
-            else:
-                runs.append(run_state[0])
+            keys_ids_list.append(key)
+        run_state = list(self._restore_dicts(runstate.JOB_RUN_STATE, keys_ids_list).values())
+        runs.extend(run_state)
         return runs
 
     def _restore_metadata(self):
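The restore() change above uses a common concurrent.futures idiom: the futures dict is keyed by future with the job name as the value, so each completed future can be routed back to the job it belongs to. A toy sketch of that idiom using the example dictionary from the commit's own comment; fetch_runs is a hypothetical stand-in for _restore_runs_for_job:

    import concurrent.futures

    def fetch_runs(job_name: str, job_state: dict) -> list:
        # stand-in for StateManager._restore_runs_for_job
        return [f"{job_name}.{num}" for num in job_state["run_nums"]]

    jobs = {
        "MASTER.k8s": {"run_nums": [0], "enabled": True},
        "MASTER.cits_test_frequent_1": {"run_nums": [1, 0], "enabled": True},
    }

    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        # map each future back to the job it was submitted for
        future_to_name = {
            executor.submit(fetch_runs, name, state): name
            for name, state in jobs.items()
        }
        for future in concurrent.futures.as_completed(future_to_name):
            jobs[future_to_name[future]]["runs"] = future.result()

    assert jobs["MASTER.k8s"]["runs"] == ["MASTER.k8s.0"]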
