Commit

Implement both downsampler and converter & implement a base type that both derive from.
VJalili committed Apr 5, 2024
1 parent 94d7541 commit 501681b
Showing 2 changed files with 125 additions and 41 deletions.
95 changes: 78 additions & 17 deletions tests/utilities/downsamplers.py
@@ -1,4 +1,6 @@
import os
import pysam
import subprocess

from typing import Callable, List
from dataclasses import dataclass
@@ -12,15 +14,60 @@ class Region:
end: int


class BaseDownsampler:
class BaseTransformer:
    def __init__(self, working_dir: str, callback: Callable[..., dict]):
        # Convert the string to an absolute path and make sure it exists.
self.working_dir = Path(working_dir).resolve(strict=True)
self.callback = callback

@staticmethod
def get_supported_file_types() -> List[str]:
"""
        Returned file types should include the '.' prefix to match Path().suffix output
        (e.g., return '.cram' instead of 'cram').
"""
raise NotImplementedError()

def get_output_filename(self, input_filename, output_prefix):
return str(self.working_dir.joinpath(f"{output_prefix}{Path(input_filename).name}"))


class BaseConverter(BaseTransformer):
    def __init__(self, working_dir: str, callback: Callable[..., dict]):
super().__init__(working_dir, callback)

@staticmethod
def get_supported_file_types() -> List[str]:
raise NotImplementedError()

def convert(self, input_filename: str, output_prefix: str) -> dict:
raise NotImplementedError()


class BedToIntervalListConverter(BaseConverter):
def __init__(self, working_dir, callback: Callable[[str], dict], sequence_dict_filename: str, picard_path: str, **kwargs):
super().__init__(working_dir, callback)
self.sequence_dict_filename = sequence_dict_filename
self.picard_path = picard_path

@staticmethod
def get_supported_file_types() -> List[str]:
return [".interval_list"]

def convert(self, input_filename: str, output_prefix: str) -> dict:
output_filename = self.get_output_filename(input_filename, output_prefix)
subprocess.run(
["java", "-jar", self.picard_path, "BedToIntervalList",
"-I", input_filename, "-O", output_filename, "-SD", self.sequence_dict_filename],
check=True)
return self.callback(output_filename)



class BaseDownsampler(BaseTransformer):
    def __init__(self, working_dir: str, callback: Callable[..., dict]):
super().__init__(working_dir, callback)

@staticmethod
def get_supported_file_types() -> List[str]:
"""
@@ -34,30 +81,44 @@ def downsample(self, input_filename: str, output_prefix: str, regions: List[Region]) -> dict:


class CramDownsampler(BaseDownsampler):
def __init__(self, working_dir, callback: Callable[[str, str], dict]):
def __init__(self, working_dir, callback: Callable[[str, str], dict], reference_fasta: str, reference_index: str, **kwargs):
super().__init__(working_dir, callback)
self.reference_fasta = reference_fasta
self.reference_index = reference_index

@staticmethod
def get_supported_file_types() -> List[str]:
return [".cram"]

def downsample(self, input_filename: str, output_prefix: str, regions: List[Region]) -> dict:
output_filename = self.get_output_filename(input_filename, output_prefix)
with \
pysam.AlignmentFile(input_filename, "rc") as input_cram_file, \
pysam.AlignmentFile(output_filename, "wc",
header=input_cram_file.header,
reference_names=input_cram_file.references) as output_cram_file:
        # Implementation notes:
        # 1. This method calls `samtools` instead of using `pysam` since it needs to include
        #    the distant mates of paired-end reads in the downsampled regions (i.e., the
        #    `--fetch-pairs` flag). Such reads are needed for MELT to function properly.
        #
        # 2. The method writes the target regions to a BED file. While taking a BED file
        #    containing the regions as input may seem the better option, the current approach
        #    is used because taking a BED file instead of a list of regions would complicate
        #    the method signature.

regions_filename = os.path.join(self.working_dir, "_tmp_regions.bed")
with open(regions_filename, "w") as f:
for region in regions:
for read in input_cram_file.fetch(region=f"{region.chr}:{region.start}-{region.end}"):
output_cram_file.write(read)
index_filename = f"{output_filename}.crai"
pysam.index(output_filename, index_filename)
return self.callback(output_filename, index_filename)
f.write("\t".join([str(region.chr), str(region.start), str(region.end)]) + "\n")

output_filename = self.get_output_filename(input_filename, output_prefix)
        subprocess.run(
            ["samtools", "view", input_filename, "--reference", self.reference_fasta,
             "--targets-file", regions_filename, "--output", output_filename, "--cram", "--fetch-pairs"],
            check=True)

subprocess.run(["samtools", "index", output_filename])

os.remove(regions_filename)
        return self.callback(output_filename, output_filename + ".crai")


class VcfDownsampler(BaseDownsampler):
def __init__(self, working_dir, callback: Callable[[str], dict]):
def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
super().__init__(working_dir, callback)

@staticmethod
@@ -89,7 +150,7 @@ class IntervalListDownsampler(BaseDownsampler):
# 1. It needs the picard tool installed;
# 2. It needs an additional "SD" input, which would make the `downsample` method signature complicated.

def __init__(self, working_dir, callback: Callable[[str], dict]):
def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
super().__init__(working_dir, callback)

@staticmethod
@@ -114,7 +175,7 @@ def downsample(self, input_filename: str, output_prefix: str, regions: List[Region]) -> dict:


class BedDownsampler(BaseDownsampler):
def __init__(self, working_dir, callback: Callable[[str], dict]):
def __init__(self, working_dir, callback: Callable[[str], dict], **kwargs):
super().__init__(working_dir, callback)

@staticmethod
@@ -136,7 +197,7 @@ def downsample(self, input_filename: str, output_prefix: str, regions: List[Region]) -> dict:


class PrimaryContigsDownsampler(BaseDownsampler):
def __init__(self, working_dir, callback: Callable[[str], dict], delimiter: str = "\t"):
def __init__(self, working_dir, callback: Callable[[str], dict], delimiter: str = "\t", **kwargs):
super().__init__(working_dir, callback)
self.delimiter = delimiter

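For orientation, here is a minimal usage sketch of the transformer hierarchy this commit introduces. It is a sketch, not part of the commit: the reference, dictionary, and picard.jar paths are hypothetical placeholders, and it assumes samtools and Java are installed and on PATH.

    # Assuming this runs from tests/utilities/, mirroring how generate_test_data.py uses the module.
    import downsamplers
    from downsamplers import Region

    # Convert a BED file to a Picard interval list (all paths are placeholders).
    converter = downsamplers.BedToIntervalListConverter(
        working_dir=".",
        callback=lambda f: {"preprocessed_intervals": f},
        sequence_dict_filename="Homo_sapiens_assembly38.dict",
        picard_path="picard.jar",
    )
    outputs = converter.convert("regions.bed", "downsampled_")

    # Downsample a CRAM to one region; this shells out to `samtools view --cram --fetch-pairs`
    # so distant mates of paired-end reads in the region are kept, then indexes the output.
    downsampler = downsamplers.CramDownsampler(
        working_dir=".",
        callback=lambda cram, index: {"bam_or_cram_file": cram, "bam_or_cram_index": index},
        reference_fasta="Homo_sapiens_assembly38.fasta",
        reference_index="Homo_sapiens_assembly38.fasta.fai",
    )
    outputs = downsampler.downsample(
        "sample.cram", "downsampled_", [Region(chr="chr20", start=1_000_000, end=2_000_000)]
    )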
71 changes: 47 additions & 24 deletions tests/utilities/generate_test_data.py
@@ -22,7 +22,7 @@ class Region:

@dataclass
class Handler:
downsampler: Union[downsamplers.BaseDownsampler, Type[downsamplers.BaseDownsampler]]
transformer: Union[downsamplers.BaseTransformer, Type[downsamplers.BaseTransformer]]
    callback: Callable[..., dict]


@@ -32,26 +32,26 @@ class Handler:
downsamplers.CramDownsampler,
lambda cram, index: {"bam_or_cram_file": cram, "bam_or_cram_index": index}
),
"preprocessed_intervals": Handler(
downsamplers.IntervalListDownsampler,
lambda x: {"preprocessed_intervals": x, "melt_metrics_intervals": x}
),
"sd_locs_vcf": Handler(
downsamplers.VcfDownsampler,
lambda x: {"sd_locs_vcf": x}
),
"primary_contigs_list": Handler(
downsamplers.PrimaryContigsDownsampler,
lambda x: {"primary_contigs_list": x}
),
"primary_contigs_fai": Handler(
downsamplers.PrimaryContigsDownsampler,
lambda x: {"primary_contigs_fai": x}
),
"wham_include_list_bed_file": Handler(
downsamplers.BedDownsampler,
lambda x: {"wham_include_list_bed_file": x}
)
# "preprocessed_intervals": Handler(
# downsamplers.BedToIntervalListConverter,
# lambda x: {"preprocessed_intervals": x, "melt_metrics_intervals": x}
# ),
# "sd_locs_vcf": Handler(
# downsamplers.VcfDownsampler,
# lambda x: {"sd_locs_vcf": x}
# ),
# "primary_contigs_list": Handler(
# downsamplers.PrimaryContigsDownsampler,
# lambda x: {"primary_contigs_list": x}
# ),
# "primary_contigs_fai": Handler(
# downsamplers.PrimaryContigsDownsampler,
# lambda x: {"primary_contigs_fai": x}
# ),
# "wham_include_list_bed_file": Handler(
# downsamplers.BedDownsampler,
# lambda x: {"wham_include_list_bed_file": x}
# )
}
}

@@ -83,7 +83,17 @@ def localize_file(input_filename, output_filename):
def initialize_downsamplers(working_dir: str):
for _, inputs in SUBJECT_WORKFLOW_INPUTS.items():
for _, handler in inputs.items():
handler.downsampler = handler.downsampler(working_dir, handler.callback)
# if type(handler.transformer) == type(downsamplers.BedToIntervalListConverter):
# handler.transformer = handler.transformer(working_dir, handler.callback, )
# exit()
handler.transformer = handler.transformer(
working_dir=working_dir,
callback=handler.callback,
reference_fasta="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta",
reference_index="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.fasta.fai",
sequence_dict_filename="gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict",
picard_path="/Users/jvahid/code/picard.jar"
)


def update_workflow_json(
Expand All @@ -106,7 +116,7 @@ def update_workflow_json(
logging.info(f"Processing input {k}.")
workflow_input_local_filename = Path(working_dir).joinpath(Path(v).name)
localize_file(v, workflow_input_local_filename)
updated_files = handler.downsampler.downsample(workflow_input_local_filename, output_filename_prefix, regions)
updated_files = handler.transformer.downsample(workflow_input_local_filename, output_filename_prefix, regions)
if bucket_name is not None and blob_name is not None:
for varname, filename in updated_files.items():
logging.info(f"Uploading downsampled file {filename} to bucket {bucket_name}.")
@@ -140,7 +150,8 @@ def main():
"In addition to other inputs, the script takes a JSON file containing the inputs to a workflow, "
"downsamples the inputs according to the defined rules (see `SUBJECT_WORKFLOW_INPUTS`), "
"pushes the downsampled files to a given cloud storage, and creates a new JSON "
"file with the updated downsampled inputs.",
"file with the updated downsampled inputs."
"This script needs samtools version 1.19.2 or newer installed and added to PATH.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument(
@@ -161,6 +172,11 @@
"name as the input JSON with an added prefix."
)

parser.add_argument(
"picard_path",
help="Sets the absolute path, including the filename, to `picard.jar`."
)

parser.add_argument(
"--output-filename-prefix",
default="downsampled_",
@@ -191,6 +207,13 @@

args = parser.parse_args()

# sd = "gs://gcp-public-data--broad-references/hg38/v0/Homo_sapiens_assembly38.dict"
# sd_out = "./tmp_homo.dict"
# localize_file(sd, sd_out)
    # test = downsamplers.BedToIntervalListConverter(".", lambda x: {}, sd_out, "/Users/jvahid/code/picard.jar")
# test.convert("default_downsampling_regions.bed", "tmp_tmp_")
#

regions = parse_target_regions(args.target_regions)
logging.info(f"Found {len(regions)} target regions for downsampling.")

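A note on the constructor changes in this commit: `initialize_downsamplers` now passes one superset of keyword arguments (reference paths, sequence dictionary, picard path) to every transformer, and each constructor keeps only the arguments it needs while `**kwargs` absorbs the rest. A minimal sketch of that pattern, with illustrative names that are not from the commit:

    from typing import Callable

    class Base:
        def __init__(self, working_dir: str, callback: Callable[..., dict]):
            self.working_dir = working_dir
            self.callback = callback

    class NeedsReference(Base):
        # Keeps reference_fasta; **kwargs silently absorbs unrelated keywords
        # (e.g., picard_path) passed by the shared initializer.
        def __init__(self, working_dir, callback, reference_fasta: str, **kwargs):
            super().__init__(working_dir, callback)
            self.reference_fasta = reference_fasta

    # One argument superset works for every subclass:
    transformer = NeedsReference(
        working_dir=".",
        callback=lambda f: {"out": f},
        reference_fasta="ref.fasta",  # consumed by NeedsReference
        picard_path="picard.jar",     # ignored via **kwargs
    )

A tradeoff of this design worth noting: because `**kwargs` swallows unknown keywords, a misspelled argument fails silently instead of raising a TypeError.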
