From 501681bd968743eeea847d071dc147bd1181555b Mon Sep 17 00:00:00 2001
From: VJalili
Date: Fri, 5 Apr 2024 15:41:39 -0400
Subject: [PATCH] Implement both downsampler and converter & implement a base
 type that both derive from.

---
 tests/utilities/downsamplers.py       | 95 ++++++++++++++++++++++-----
 tests/utilities/generate_test_data.py | 71 +++++++++++++-------
 2 files changed, 125 insertions(+), 41 deletions(-)

diff --git a/tests/utilities/downsamplers.py b/tests/utilities/downsamplers.py
index 761a69179..c13308730 100644
--- a/tests/utilities/downsamplers.py
+++ b/tests/utilities/downsamplers.py
@@ -1,4 +1,6 @@
+import os
 import pysam
+import subprocess
 from typing import Callable, List
 from dataclasses import dataclass
@@ -12,15 +14,60 @@ class Region:
     end: int
 
 
-class BaseDownsampler:
+class BaseTransformer:
     def __init__(self, working_dir: str, callback: Callable[[str, ...], dict]):
         # Convert the string to an ABS path and make sure it exists.
         self.working_dir = Path(working_dir).resolve(strict=True)
         self.callback = callback
 
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        """
+        The file types should include the '.' prefix to match Path().suffix output
+        (e.g., it should return '.cram' instead of 'cram').
+        """
+        raise NotImplementedError()
+
     def get_output_filename(self, input_filename, output_prefix):
         return str(self.working_dir.joinpath(f"{output_prefix}{Path(input_filename).name}"))
 
+
+class BaseConverter(BaseTransformer):
+    def __init__(self, working_dir: str, callback: Callable[..., dict]):
+        super().__init__(working_dir, callback)
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        raise NotImplementedError()
+
+    def convert(self, input_filename: str, output_prefix: str) -> dict:
+        raise NotImplementedError()
+
+
+class BedToIntervalListConverter(BaseConverter):
+    def __init__(self, working_dir, callback: Callable[[str], dict], sequence_dict_filename: str, picard_path: str, **kwargs):
+        super().__init__(working_dir, callback)
+        self.sequence_dict_filename = sequence_dict_filename
+        self.picard_path = picard_path
+
+    @staticmethod
+    def get_supported_file_types() -> List[str]:
+        return [".interval_list"]
+
+    def convert(self, input_filename: str, output_prefix: str) -> dict:
+        # Run Picard's BedToIntervalList; the sequence dictionary (-SD) provides
+        # the contig metadata for the header of the output interval_list.
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        subprocess.run(
+            ["java", "-jar", self.picard_path, "BedToIntervalList",
+             "-I", input_filename, "-O", output_filename, "-SD", self.sequence_dict_filename],
+            check=True)
+        return self.callback(output_filename)
+
+
+class BaseDownsampler(BaseTransformer):
+    def __init__(self, working_dir: str, callback: Callable[..., dict]):
+        super().__init__(working_dir, callback)
+
     @staticmethod
     def get_supported_file_types() -> List[str]:
         """
@@ -34,30 +81,44 @@ def downsample(self, input_filename: str, output_prefix: str, regions: List[Regi
 
 
 class CramDownsampler(BaseDownsampler):
-    def __init__(self, working_dir, callback: Callable[[str, str], dict]):
+    def __init__(self, working_dir, callback: Callable[[str, str], dict], reference_fasta: str, reference_index: str, **kwargs):
+        # **kwargs absorbs keyword arguments intended for other transformer types
+        # (see initialize_downsamplers in generate_test_data.py).
         super().__init__(working_dir, callback)
+        self.reference_fasta = reference_fasta
+        self.reference_index = reference_index
 
     @staticmethod
     def get_supported_file_types() -> List[str]:
         return [".cram"]
 
     def downsample(self, input_filename: str, output_prefix: str, regions: List[Region]) -> dict:
-        output_filename = self.get_output_filename(input_filename, output_prefix)
-        with \
-                pysam.AlignmentFile(input_filename, "rc") as input_cram_file, \
-                pysam.AlignmentFile(output_filename, "wc",
-                                    header=input_cram_file.header,
-                                    reference_names=input_cram_file.references) as output_cram_file:
+        # Implementation notes:
+        # 1. This method calls `samtools` instead of using `pysam` since it needs the
+        #    `--fetch-pairs` flag to include distant paired-end mates of reads in the
+        #    downsampled regions. Such reads are needed for MELT to function properly.
+        #
+        # 2. The method writes the target regions to a temporary BED file. While taking a
+        #    BED file containing the regions as input may seem the better option, it would
+        #    convolute the shared `downsample` signature, which takes a list of regions.
+
+        regions_filename = os.path.join(self.working_dir, "_tmp_regions.bed")
+        with open(regions_filename, "w") as f:
             for region in regions:
-                for read in input_cram_file.fetch(region=f"{region.chr}:{region.start}-{region.end}"):
-                    output_cram_file.write(read)
-        index_filename = f"{output_filename}.crai"
-        pysam.index(output_filename, index_filename)
-        return self.callback(output_filename, index_filename)
+                f.write("\t".join([str(region.chr), str(region.start), str(region.end)]) + "\n")
+
+        output_filename = self.get_output_filename(input_filename, output_prefix)
+        subprocess.run(
+            ["samtools", "view", input_filename, "--reference", self.reference_fasta,
+             "--targets-file", regions_filename, "--output", output_filename, "--cram", "--fetch-pairs"],
+            check=True
+        )
+
+        subprocess.run(["samtools", "index", output_filename], check=True)
+
+        os.remove(regions_filename)
+        # `samtools index` writes the CRAM index next to the output with a .crai suffix.
+        return self.callback(output_filename, output_filename + ".crai")
 
 
 class VcfDownsampler(BaseDownsampler):
-    def __init__(self, working_dir, callback: Callable[[str], dict]):
+    def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
         super().__init__(working_dir, callback)
 
     @staticmethod
@@ -89,7 +150,7 @@ class IntervalListDownsampler(BaseDownsampler):
     # 1. It needs the picard tool installed;
     # 2. It needs an additional "SD" input, which would make the `downsample` method signature complicated.
-    def __init__(self, working_dir, callback: Callable[[str], dict]):
+    def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
         super().__init__(working_dir, callback)
 
     @staticmethod
@@ -114,7 +175,7 @@ def downsample(self, input_filename: str, output_prefix: str, regions: List[Regi
 
 
 class BedDownsampler(BaseDownsampler):
-    def __init__(self, working_dir, callback: Callable[[str], dict]):
+    def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
         super().__init__(working_dir, callback)
 
     @staticmethod
@@ -136,7 +197,7 @@ def downsample(self, input_filename: str, output_prefix: str, regions: List[Regi
 
 
 class PrimaryContigsDownsampler(BaseDownsampler):
-    def __init__(self, working_dir, callback: Callable[[str], dict], delimiter: str = "\t"):
+    def __init__(self, working_dir, callback: Callable[[str], dict], delimiter: str = "\t", **kwargs):
         super().__init__(working_dir, callback)
         self.delimiter = delimiter
 
diff --git a/tests/utilities/generate_test_data.py b/tests/utilities/generate_test_data.py
index b527043dc..853d9949f 100644
--- a/tests/utilities/generate_test_data.py
+++ b/tests/utilities/generate_test_data.py
@@ -22,7 +22,7 @@ class Region:
 
 @dataclass
 class Handler:
-    downsampler: Union[downsamplers.BaseDownsampler, Type[downsamplers.BaseDownsampler]]
+    transformer: Union[downsamplers.BaseTransformer, Type[downsamplers.BaseTransformer]]
     callback: Callable[[str, ...], dict]
 
 
@@ -32,26 +32,26 @@ class Handler:
             downsamplers.CramDownsampler,
             lambda cram, index: {"bam_or_cram_file": cram, "bam_or_cram_index": index}
         ),
-        "preprocessed_intervals": Handler(
-            downsamplers.IntervalListDownsampler,
-            lambda x: {"preprocessed_intervals": x, "melt_metrics_intervals": x}
-        ),
-        "sd_locs_vcf": Handler(
-            downsamplers.VcfDownsampler,
-            lambda x: {"sd_locs_vcf": x}
-        ),
-        "primary_contigs_list": Handler(
-            downsamplers.PrimaryContigsDownsampler,
-            lambda x: {"primary_contigs_list": x}
-        ),
-        "primary_contigs_fai": Handler(
-            downsamplers.PrimaryContigsDownsampler,
-            lambda x: {"primary_contigs_fai": x}
-        ),
-        "wham_include_list_bed_file": Handler(
-            downsamplers.BedDownsampler,
-            lambda x: {"wham_include_list_bed_file": x}
-        )
+        # "preprocessed_intervals": Handler(
+        #     downsamplers.BedToIntervalListConverter,
+        #     lambda x: {"preprocessed_intervals": x, "melt_metrics_intervals": x}
+        # ),
+        # "sd_locs_vcf": Handler(
+        #     downsamplers.VcfDownsampler,
+        #     lambda x: {"sd_locs_vcf": x}
+        # ),
+        # "primary_contigs_list": Handler(
+        #     downsamplers.PrimaryContigsDownsampler,
+        #     lambda x: {"primary_contigs_list": x}
+        # ),
+        # "primary_contigs_fai": Handler(
+        #     downsamplers.PrimaryContigsDownsampler,
+        #     lambda x: {"primary_contigs_fai": x}
+        # ),
+        # "wham_include_list_bed_file": Handler(
+        #     downsamplers.BedDownsampler,
+        #     lambda x: {"wham_include_list_bed_file": x}
+        # )
     }
 }
 
@@ -83,7 +83,17 @@ def localize_file(input_filename, output_filename):
 def initialize_downsamplers(working_dir: str):
     for _, inputs in SUBJECT_WORKFLOW_INPUTS.items():
         for _, handler in inputs.items():
-            handler.downsampler = handler.downsampler(working_dir, handler.callback)
+            # Pass the same superset of keyword arguments to every transformer;
+            # each type keeps what it needs and ignores the rest via **kwargs.
+            handler.transformer = handler.transformer(
+                working_dir=working_dir,
+                callback=handler.callback,
+                reference_fasta="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta",
reference_index="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta.fai", + sequence_dict_filename="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict", + picard_path="/Users/jvahid/code/picard.jar" + ) def update_workflow_json( @@ -106,7 +116,7 @@ def update_workflow_json( logging.info(f"Processing input {k}.") workflow_input_local_filename = Path(working_dir).joinpath(Path(v).name) localize_file(v, workflow_input_local_filename) - updated_files = handler.downsampler.downsample(workflow_input_local_filename, output_filename_prefix, regions) + updated_files = handler.transformer.downsample(workflow_input_local_filename, output_filename_prefix, regions) if bucket_name is not None and blob_name is not None: for varname, filename in updated_files.items(): logging.info(f"Uploading downsampled file {filename} to bucket {bucket_name}.") @@ -140,7 +150,8 @@ def main(): "In addition to other inputs, the script takes a JSON file containing the inputs to a workflow, " "downsamples the inputs according to the defined rules (see `SUBJECT_WORKFLOW_INPUTS`), " "pushes the downsampled files to a given cloud storage, and creates a new JSON " - "file with the updated downsampled inputs.", + "file with the updated downsampled inputs." + "This script needs samtools version 1.19.2 or newer installed and added to PATH.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( @@ -161,6 +172,11 @@ def main(): "name as the input JSON with an added prefix." ) + parser.add_argument( + "picard_path", + help="Sets the absolute path, including the filename, to `picard.jar`." + ) + parser.add_argument( "--output-filename-prefix", default="downsampled_", @@ -191,6 +207,13 @@ def main(): args = parser.parse_args() + # sd = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict" + # sd_out = "./tmp_homo.dict" + # localize_file(sd, sd_out) + # test = downsamplers.BedToIntervallistConverter(".", lambda x: {}, sd_out, "/Users/jvahid/code/picard.jar") + # test.convert("default_downsampling_regions.bed", "tmp_tmp_") + # + regions = parse_target_regions(args.target_regions) logging.info(f"Found {len(regions)} target regions for downsampling.")