diff --git a/lib/idseq_utils/idseq_utils/batch_run_helpers.py b/lib/idseq_utils/idseq_utils/batch_run_helpers.py
index 0aa081c1..313f93dc 100644
--- a/lib/idseq_utils/idseq_utils/batch_run_helpers.py
+++ b/lib/idseq_utils/idseq_utils/batch_run_helpers.py
@@ -1,3 +1,4 @@
+import hashlib
 import json
 import logging
 import os
@@ -9,7 +10,7 @@
 from os import listdir
 from multiprocessing import Pool
 from subprocess import run
-from typing import Dict, List
+from typing import Dict, List, Optional
 from urllib.parse import urlparse
 
 from idseq_utils.diamond_scatter import blastx_join
@@ -19,6 +20,11 @@
 from botocore.exceptions import ClientError
 from botocore.config import Config
 
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S',
+)
 log = logging.getLogger(__name__)
 
 MAX_CHUNKS_IN_FLIGHT = 30  # TODO: remove this constant, currently does nothing since we have at most 30 index chunks
@@ -83,25 +89,62 @@ def _get_job_status(job_id, use_batch_api=False):
         raise e
 
 
+class BatchJobCache:
+    """
+    BatchJobCache saves job IDs so the coordinator can re-attach to running batch jobs after a coordinator failure.
+
+    The output should always be the same if the inputs are the same; however, we also incorporate the batch_args
+    into the cache key because a retry on spot vs. on demand will result in a different batch queue.
+    """
+    def __init__(self, bucket: str, prefix: str, inputs: Dict[str, str]):
+        self.bucket = bucket
+        self.prefix = prefix
+        self.inputs = inputs
+
+    def _key(self, batch_args: Dict) -> str:
+        hash = hashlib.sha256()
+        cache_dict = {"inputs": self.inputs, "batch_args": batch_args}
+        hash.update(json.dumps(cache_dict, sort_keys=True).encode())
+        return os.path.join(self.prefix, hash.hexdigest())
+
+    def get(self, batch_args: Dict) -> Optional[str]:
+        try:
+            resp = _s3_client.get_object(Bucket=self.bucket, Key=self._key(batch_args))
+            return resp["Body"].read().decode()
+        except ClientError as e:
+            if e.response["Error"]["Code"] == "NoSuchKey":
+                return None
+            else:
+                raise e
+
+    def put(self, batch_args: Dict, job_id: str):
+        _s3_client.put_object(
+            Bucket=self.bucket,
+            Key=self._key(batch_args),
+            Body=job_id.encode(),
+            Tagging="AlignmentCoordination=True",
+        )
+
+
 def _run_batch_job(
     job_name: str,
     job_queue: str,
     job_definition: str,
     environment: Dict[str, str],
     retries: int,
+    cache: BatchJobCache,
 ):
-    response = _batch_client.submit_job(
-        jobName=job_name,
-        jobQueue=job_queue,
-        jobDefinition=job_definition,
-        containerOverrides={
+    submit_args = {
+        "jobName": job_name,
+        "jobQueue": job_queue,
+        "jobDefinition": job_definition,
+        "containerOverrides": {
             "environment": [{"name": k, "value": v} for k, v in environment.items()],
             "memory": 130816,
             "vcpus": 24,
         },
-        retryStrategy={"attempts": retries},
-    )
-    job_id = response["jobId"]
+        "retryStrategy": {"attempts": retries},
+    }
 
     def _log_status(status: str):
         level = logging.INFO if status != "FAILED" else logging.ERROR
@@ -119,7 +162,14 @@ def _log_status(status: str):
             ),
         )
 
-    _log_status("SUBMITTED")
+    job_id = cache.get(submit_args)
+    if job_id:
+        log.info(f"reattach to batch job: {job_id}")
+    else:
+        response = _batch_client.submit_job(**submit_args)
+        job_id = response["jobId"]
+        cache.put(submit_args, job_id)
+        _log_status("SUBMITTED")
 
     delay = 60 + random.randint(
         -60 // 2, 60 // 2
@@ -194,26 +244,19 @@ def _job_queue(provisioning_model: str):
     input_bucket, input_key = _bucket_and_key(wdl_input_uri)
     wdl_output_uri = os.path.join(chunk_dir, f"{chunk_id}-output.json")
-    output_bucket, output_key = _bucket_and_key(wdl_output_uri)
     wdl_workflow_uri = f"s3://idseq-workflows/{aligner}-{aligner_wdl_version}/{aligner}.wdl"
 
-    # if this job fails we don't want to re-run chunks that have already been processed
-    # the presence of the output file means the chunk has already been processed
-    try:
-        _s3_client.head_object(Bucket=output_bucket, Key=output_key)
-        log.info(f"skipping chunk, output already exists: {wdl_output_uri}")
-        return
-    except ClientError as e:
-        # raise the error if it is anything other than "not found"
-        if e.response["Error"]["Code"] != "404":
-            raise e
+    cache_prefix_uri = os.path.join(chunk_dir, "batch_job_cache/")
+    cache_bucket, cache_prefix = _bucket_and_key(cache_prefix_uri)
+    cache = BatchJobCache(cache_bucket, cache_prefix, inputs)
 
     _s3_client.put_object(
         Bucket=input_bucket,
         Key=input_key,
         Body=json.dumps(inputs).encode(),
         ContentType="application/json",
+        Tagging="AlignmentCoordination=True",
     )
 
     environment = {
@@ -231,6 +274,7 @@ def _job_queue(provisioning_model: str):
             job_definition=job_definition,
             environment=environment,
             retries=2,
+            cache=cache,
         )
     except BatchJobFailed:
         _run_batch_job(
@@ -239,6 +283,7 @@ def _job_queue(provisioning_model: str):
             job_definition=job_definition,
             environment=environment,
             retries=1,
+            cache=cache,
         )
 
 
@@ -263,6 +308,7 @@ def run_alignment(
 ):
     bucket, prefix = _bucket_and_key(db_path)
     chunk_dir = os.path.join(input_dir, f"{aligner}-chunks")
+    _, chunk_prefix = _bucket_and_key(chunk_dir)
     chunks = (
         [
             input_dir,
@@ -281,9 +327,16 @@ def run_alignment(
     run(["s3parcp", "--recursive", chunk_dir, "chunks"], check=True)
     if os.path.exists(os.path.join("chunks", "cache")):
         shutil.rmtree(os.path.join("chunks", "cache"))
+    if os.path.exists(os.path.join("chunks", "batch_job_cache")):
+        shutil.rmtree(os.path.join("chunks", "batch_job_cache"))
     for fn in listdir("chunks"):
         if fn.endswith("json"):
             os.remove(os.path.join("chunks", fn))
+            _s3_client.put_object_tagging(
+                Bucket=bucket,
+                Key=os.path.join(chunk_prefix, fn),
+                Tagging={"TagSet": [{"Key": "AlignmentCoordination", "Value": "True"}]},
+            )
     if aligner == "diamond":
         blastx_join("chunks", result_path, aligner_args, *queries)
     else: