Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Always remember to flush. #146

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
332 changes: 147 additions & 185 deletions src/finaletoolkit/utils/_filter_file.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,63 @@
from __future__ import annotations
import tempfile as tf
import subprocess
import traceback
import logging
import warnings
import gzip
import pysam

def validate_deprecated_args(old_arg, new_arg, old_name, new_name):
    """Resolve a deprecated alias against its replacement argument.

    Emits a ``DeprecationWarning`` whenever the deprecated *old_arg* is
    supplied.  Returns *new_arg* when only it is given (or both are
    ``None``), returns *old_arg* when only the deprecated form is given,
    and raises ``ValueError`` when both are specified at once.
    """
    if old_arg is None:
        # Deprecated form not used; pass the modern argument through as-is.
        return new_arg
    warnings.warn(f"{old_name} is deprecated. Use {new_name} instead.",
                  category=DeprecationWarning,
                  stacklevel=2)
    if new_arg is not None:
        # Both forms given — ambiguous, refuse rather than pick one.
        raise ValueError(f'{old_name} and {new_name} cannot both be specified.')
    return old_arg

def validate_input_file(input_file):
    """Check that *input_file* has a supported suffix and return it.

    Supported inputs are bgzipped BED-like files (``.gz``, ``.bgz``) and
    alignment files (``.bam``, ``.cram``).

    Returns the matching suffix string (e.g. ``".bam"``), which callers
    use to name intermediate and output files.  Raises ``ValueError`` for
    any other extension.
    """
    # A tuple (not a set) keeps the error message deterministic.  ".bgz"
    # must be listed explicitly: "x.bgz".endswith(".gz") is False, so it
    # is not covered by the ".gz" entry.
    valid_suffixes = (".gz", ".bgz", ".bam", ".cram")
    for suffix in valid_suffixes:
        if input_file.endswith(suffix):
            return suffix
    raise ValueError(
        "Input file should have one of the following suffixes: "
        f"{', '.join(valid_suffixes)}")

def run_subprocess(cmd: str, error_msg: str = "Command failed", verbose: bool = False, logger=None):
    """Run *cmd* through the shell, raising on a nonzero exit status.

    Parameters
    ----------
    cmd : str
        Shell command line to execute.  NOTE(review): executed with
        ``shell=True`` — callers must not interpolate untrusted input.
    error_msg : str, optional
        Prefix for the log message emitted when the command fails.
    verbose : bool, optional
        When True, log the command line before running it.
    logger : logging.Logger, optional
        Logger to use; defaults to this module's logger.  (Previously a
        missing logger caused an ``AttributeError`` on ``None`` instead
        of reporting the real failure.)

    Raises
    ------
    subprocess.CalledProcessError
        Re-raised after logging when the command exits nonzero.
    """
    if logger is None:
        logger = logging.getLogger(__name__)
    if verbose:
        # Lazy %-formatting so the message is only built when emitted.
        logger.info("Running: %s", cmd)
    try:
        subprocess.run(cmd, shell=True, check=True)
    except subprocess.CalledProcessError as e:
        logger.error("%s: %s", error_msg, e)
        raise

def filter_bed_entries(infile, min_length=None, max_length=None, quality_threshold=30):
    """Yield lines of a BED-like fragment file passing length/MAPQ filters.

    Parameters
    ----------
    infile : iterable of str
        Open text handle (or any iterable of tab-separated lines).
        Comment lines starting with ``#`` are skipped.
    min_length, max_length : int, optional
        Inclusive bounds on fragment length (``end - start``); ``None``
        disables the corresponding bound.
    quality_threshold : float, optional
        Minimum MAPQ score for a line to be yielded.

    Yields
    ------
    str
        Input lines (unmodified) that satisfy all filters.  Lines whose
        coordinate or score fields fail to parse are silently skipped.

    Raises
    ------
    ValueError
        If a non-comment line has fewer than 5 columns, since the MAPQ
        column cannot be located.
    """
    def _mapq_index(n_cols):
        # Fix: the original helper ignored its parameter and re-read
        # ``parts`` from the enclosing scope; use the argument instead.
        if n_cols < 5:
            raise ValueError("There are not enough columns in the BED file to determine the MAPQ column")
        # 5-column fragment files carry MAPQ in column 4 (index 3);
        # wider files carry it in column 5 (index 4).
        return 3 if n_cols == 5 else 4

    for line in infile:
        if line.startswith('#'):
            continue

        parts = line.strip().split('\t')
        mapq_column = _mapq_index(len(parts))

        try:
            start = int(parts[1])
            end = int(parts[2])
            length = end - start
            score = float(parts[mapq_column])

            if ((min_length is None or length >= min_length) and
                (max_length is None or length <= max_length) and
                score >= quality_threshold):
                yield line

        except (ValueError, IndexError):
            # Unparseable coordinates/score: drop the line, keep going.
            continue

def filter_file(
input_file: str,
whitelist_file: str | None = None,
Expand Down Expand Up @@ -43,7 +95,7 @@ def filter_file(
Maximum length for reads/intervals
intersect_policy: str, optional
Specifies how to determine whether fragments are in interval for
whitelisting and blacklisting functionality.'midpoint' (default)
whitelisting and blacklisting functionality. 'midpoint' (default)
calculates the central coordinate of each fragment and only
selects the fragment if the midpoint is in the interval.
'any' includes fragments with any overlap with the interval.
Expand All @@ -63,7 +115,7 @@ def filter_file(
output_file : str
Path to the filtered output file.
"""

logger = logging.getLogger(__name__)
if verbose:
print(
f"""
Expand All @@ -79,207 +131,117 @@ def filter_file(
verbose: {verbose}
"""
)

# Pass aliases and check for conflicts
if fraction_low is not None and min_length is None:
min_length = fraction_low
warnings.warn("fraction_low is deprecated. Use min_length instead.",
category=DeprecationWarning,
stacklevel=2)
elif fraction_low is not None and min_length is not None:
warnings.warn("fraction_low is deprecated. Use min_length instead.",
category=DeprecationWarning,
stacklevel=2)
raise ValueError(
'fraction_low and min_length cannot both be specified')

if fraction_high is not None and max_length is None:
max_length = fraction_high
warnings.warn("fraction_high is deprecated. Use max_length instead.",
category=DeprecationWarning,
stacklevel=2)
elif fraction_high is not None and max_length is not None:
warnings.warn("fraction_high is deprecated. Use max_length instead.",
category=DeprecationWarning,
stacklevel=2)
raise ValueError(
'fraction_high and max_length cannot both be specified.')
logging.basicConfig(level=logging.INFO)

if input_file.endswith(".gz"):
suffix = ".gz"
elif input_file.endswith(".bgz"):
suffix = ".bgz"
elif input_file.endswith(".bam"):
suffix = ".bam"
elif input_file.endswith(".cram"):
suffix = ".cram"
# Pass aliases and check for conflicts
min_length = validate_deprecated_args(fraction_low, min_length, "fraction_low", "min_length")
max_length = validate_deprecated_args(fraction_high, max_length, "fraction_high", "max_length")

suffix = validate_input_file(input_file)

if intersect_policy == "midpoint":
intersect_param = "-f 0.500"
elif intersect_policy == "any":
intersect_param = ""
else:
raise ValueError('Input file should have suffix .bam, .cram, .bgz, or .gz')

# create tempfile to contain filtered output
if output_file is None:
_, output_file = tf.mkstemp(suffix=suffix)
elif not output_file.endswith(suffix) and output_file != '-':
raise ValueError('Output file should share same suffix as input file.')

intersect = "-f 0.500" if intersect_policy == "midpoint" else ""
raise ValueError("intersect_policy must be 'midpoint' or 'any'")

pysam.set_verbosity(pysam.set_verbosity(0))

with tf.TemporaryDirectory() as temp_dir:
temp_1 = f"{temp_dir}/output1{suffix}"
temp_2 = f"{temp_dir}/output2{suffix}"
temp_3 = f"{temp_dir}/output3{suffix}"
if input_file.endswith(('.bam', '.cram')):
# create temp dir to store intermediate sorted file
if whitelist_file is not None:
try:
subprocess.run(
f"bedtools intersect -abam {input_file} -b {whitelist_file} {intersect} > {temp_1} && samtools index {temp_1}",
shell=True,
check=True)
except Exception as e:
print(e)
traceback.print_exc()
exit(1)
if whitelist_file:
run_subprocess(
f"bedtools intersect -abam {input_file} -b {whitelist_file} {intersect_param} > {temp_1} && "
f"samtools index {temp_1}",
error_msg="Whitelist filtering failed",
verbose=verbose,
logger=logger
)
else:
subprocess.run(
f"cp {input_file} {temp_1}", shell=True, check=True)
if blacklist_file is not None:
try:
subprocess.run(
f"bedtools intersect -abam {temp_1} -b {blacklist_file} -v {intersect} > {temp_2} && samtools index {temp_2}",
shell=True,
check=True)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"cp {input_file} {temp_1}", verbose=verbose, logger=logger)
if blacklist_file:
intersect_param = "-f 0.500" if intersect_policy == "midpoint" else ""
run_subprocess(
f"bedtools intersect -abam {temp_1} -b {blacklist_file} -v {intersect_param} > {temp_2} && "
f"samtools index {temp_2}",
error_msg="Blacklist filtering failed",
verbose=verbose,
logger=logger
)
else:
subprocess.run(
f"mv {temp_1} {temp_2}", shell=True, check=True)

try:
subprocess.run(
f"samtools view {temp_2} -F 3852 -f 3 -b -h -o {temp_3} -q {quality_threshold} -@ {workers}",
shell=True,
check=True)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"mv {temp_1} {temp_2}", verbose=verbose, logger=logger)

run_subprocess(
f"samtools view {temp_2} -F 3852 -f 3 -b -h -o {temp_3} -q {quality_threshold} -@ {workers}",
error_msg="Quality filtering failed",
verbose=verbose,
logger=logger
)

# filter for reads on different reference and length
with pysam.AlignmentFile(temp_3, 'rb',threads=workers//3) as in_file:
with pysam.AlignmentFile(
output_file, 'wb', template=in_file, threads=workers-workers//3) as out_file:
# Length filtering and final output
pysam.set_verbosity(0)
with pysam.AlignmentFile(temp_3, 'rb', threads=workers//3) as in_file:
with pysam.AlignmentFile(output_file, 'wb', template=in_file, threads=workers-workers//3) as out_file:
for read in in_file:
if (
read.reference_name == read.next_reference_name
and (max_length is None
or read.template_length <= max_length)
and (min_length is None
or read.template_length >= min_length)
):
if (read.reference_name == read.next_reference_name and
(max_length is None or read.template_length <= max_length) and
(min_length is None or read.template_length >= min_length)):
out_file.write(read)
outfile.flush()

if output_file != '-':
# generate index for output_file
try:
subprocess.run(
f'samtools index {output_file} {output_file}.bai',
shell=True,
check=True
run_subprocess(
f'samtools index {output_file}',
error_msg="Index creation failed",
verbose=verbose,
logger=logger
)

elif input_file.endswith('.gz'):
with gzip.open(input_file, 'rt') as infile, open(temp_1, 'w') as outfile:
for line in filter_bed_entries(infile, min_length, max_length, quality_threshold):
outfile.write(line)
outfile.flush()

if whitelist_file:
intersect_param = "-f 0.500" if intersect_policy == "midpoint" else ""
run_subprocess(
f"bedtools intersect -a {temp_1} -b {whitelist_file} {intersect_param} > {temp_2}",
error_msg="Whitelist filtering failed",
verbose=verbose,
logger=logger
)
except Exception:
traceback.print_exc()
exit(1)

elif input_file.endswith('.gz') or input_file.endswith('.bgz'):
with gzip.open(input_file, 'r') as infile, open(temp_1, 'w') as outfile:
mapq_column = 0 # 1-index for sanity when comparing with len()
for line in infile:
line = line.decode('utf-8')
parts = line.strip().split('\t')
if len(parts) < max(mapq_column,4) or line.startswith('#'):
continue

if mapq_column == 0:
if parts[4-1].isnumeric():
mapq_column = 4
elif len(parts) >= 5 and parts[5-1].isnumeric():
mapq_column = 5
else:
continue
try:
start = int(parts[1])
end = int(parts[2])
length = end - start
score = None
try:
score = float(parts[mapq_column-1])
except ValueError:
pass

passes_length_restriction = True

if min_length is not None and length < min_length:
passes_length_restriction = False

if max_length is not None and length > max_length:
passes_length_restriction = False

passes_quality_restriction = True
if score is None or score < quality_threshold:
passes_quality_restriction = False

if passes_length_restriction and passes_quality_restriction:
outfile.write(line)
except ValueError:
continue
if whitelist_file is not None:
try:
subprocess.run(
f"bedtools intersect -a {temp_1} -b {whitelist_file} {intersect} > {temp_2}",
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
else:
subprocess.run(f"mv {temp_1} {temp_2}", shell=True, check=True)

if blacklist_file is not None:
try:
subprocess.run(
f"bedtools intersect -v -a {temp_2} -b {blacklist_file} {intersect} > {temp_3}",
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"mv {temp_1} {temp_2}", verbose=verbose, logger=logger)

if blacklist_file:
intersect_param = "-f 0.500" if intersect_policy == "midpoint" else ""
run_subprocess(
f"bedtools intersect -v -a {temp_2} -b {blacklist_file} {intersect_param} > {temp_3}",
error_msg="Blacklist filtering failed",
verbose=verbose,
logger=logger
)
else:
subprocess.run(f"mv {temp_2} {temp_3}", shell=True, check=True)
try:
subprocess.run(
f"bgzip -@ {workers} -c {temp_3} > {output_file}",
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
run_subprocess(f"mv {temp_2} {temp_3}", verbose=verbose, logger=logger)

# Compression and indexing
run_subprocess(
f"bgzip -@ {workers} -c {temp_3} > {output_file}",
error_msg="Compression failed",
verbose=verbose,
logger=logger
)

if output_file != '-':
# generate index for output_file
try:
subprocess.run(
f'tabix -p bed {output_file}',
shell=True,
check=True
)
except Exception:
traceback.print_exc()
exit(1)
else:
raise ValueError("Input file must be a BAM, CRAM, or bgzipped BED file.")
run_subprocess(
f'tabix -p bed {output_file}',
error_msg="Index creation failed",
verbose=verbose,
logger=logger
)
return output_file
Loading