Skip to content

Commit

Permalink
parallelize q_score
Browse files Browse the repository at this point in the history
  • Loading branch information
colinvwood committed Dec 17, 2024
1 parent e6f60a1 commit 08f009a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 33 deletions.
85 changes: 53 additions & 32 deletions q2_quality_filter/_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass
from enum import Enum
import gzip
import multiprocessing
import os
from pathlib import Path

Expand Down Expand Up @@ -323,7 +324,7 @@ def _get_input_filepaths(


def _get_output_filepaths(
sample_id: str, input_manifest: pd.DataFrame, format: _ReadDirectionTypes
sample_id: str, input_manifest: pd.DataFrame, format_path: Path
) -> tuple[Path, Path | None]:
'''
Looks up and returns the forward fastq (and if it should exist) reverse
Expand All @@ -337,8 +338,8 @@ def _get_output_filepaths(
input_manifest : pd.DataFrame
The single-end or paired-end input fastq format's manifest viewed as a
`pd.DataFrame`.
format: _ReadDirectionTypes
A single-end or paired-end fastq directory format.
format_path: Path
The single-end or paired-end output format's path.
Returns
-------
Expand All @@ -349,9 +350,9 @@ def _get_output_filepaths(
forward_input_fp, reverse_input_fp = _get_input_filepaths(
sample_id, input_manifest
)
forward_output_fp = Path(format.path / Path(forward_input_fp).name)
forward_output_fp = format_path / Path(forward_input_fp).name
if reverse_input_fp is not None:
reverse_output_fp = Path(format.path / Path(reverse_input_fp).name)
reverse_output_fp = format_path / Path(reverse_input_fp).name
else:
reverse_output_fp = None

Expand All @@ -360,8 +361,8 @@ def _get_output_filepaths(

def _process_sample(
sample_id: str,
input_format: _ReadDirectionTypes,
output_format: _ReadDirectionTypes,
paired: bool,
output_format_path: Path,
demux_manifest: pd.DataFrame,
phred_offset: int,
min_quality: int,
Expand All @@ -380,14 +381,27 @@ def _process_sample(
----------
sample_id : str
The sample id of the to-be-processed fastq file(s).
input_format : _ReadDirectionTypes
The input format containing single-end or paired-end fastq files.
output_format : _ReadDirectionTypes
The output format to which to write processed single-end or paired-end
fastq files.
paired : bool
Whether the input data is paired-end.
output_format_path : Path
The single-end or paired-end output format's path.
demux_manifest : pd.DataFrame
The input demux manifest containing a mapping from sample id to forward
(and if present) reverse fastq filepaths.
phred_offset : int
The PHRED encoding of the record's quality scores.
min_quality : int
The minimum quality that a base must have in order to not be considered
part of a low quality window.
quality_window : int
The length of the low quality window to search for.
min_length_fraction : float
The fraction of its original length that a record's length must exceed
in order for the record to be retained.
max_ambiguous : int
The maximum number of ambiguous bases a record may contain to be
retained.
Returns
-------
Expand All @@ -412,13 +426,12 @@ def _process_sample(
sample_id, demux_manifest
)
forward_output_fp, reverse_output_fp = _get_output_filepaths(
sample_id, demux_manifest, output_format
sample_id, demux_manifest, output_format_path
)

# open output filehandle(s) and create fastq record iterator
forward_fh = gzip.open(forward_output_fp, mode='wb')

paired = isinstance(input_format, SingleLanePerSamplePairedEndFastqDirFmt)
if paired:
reverse_fh = gzip.open(reverse_output_fp, mode='wb')

Expand Down Expand Up @@ -486,7 +499,8 @@ def q_score(
min_quality: int = 4,
quality_window: int = 3,
min_length_fraction: float = 0.75,
max_ambiguous: int = 0
max_ambiguous: int = 0,
num_processes: int = 1,
) -> (_ReadDirectionUnion, pd.DataFrame):
'''
Parameter defaults as used in Bokulich et al, Nature Methods 2013, same as
Expand All @@ -512,26 +526,32 @@ def q_score(
)['phred-offset']
demux_manifest_df = demux.manifest.view(pd.DataFrame)

sample_results = []
for barcode_id, sample_id in enumerate(demux_manifest_df.index.values):
sample_stats = _process_sample(
sample_id=sample_id,
input_format=demux,
output_format=result,
demux_manifest=demux_manifest_df,
phred_offset=phred_offset,
min_quality=min_quality,
quality_window=quality_window,
min_length_fraction=min_length_fraction,
max_ambiguous=max_ambiguous,
)
# create per-sample arguments for parallel invocations
sample_ids = demux_manifest_df.index
parameters = {
'paired': paired,
'output_format_path': Path(result.path),
'demux_manifest': demux_manifest_df,
'phred_offset': phred_offset,
'min_quality': min_quality,
'quality_window': quality_window,
'min_length_fraction': min_length_fraction,
'max_ambiguous': max_ambiguous,
}
per_sample_arguments = [
[sample_id] + list(parameters.values()) for sample_id in sample_ids
]

sample_results.append(sample_stats)
# schedule samples to processes
with multiprocessing.Pool(num_processes) as pool:
all_sample_stats = pool.starmap(_process_sample, per_sample_arguments)

# update fastq manifest if sample retained
# update fastq manifest for retained samples
for sample_stats in all_sample_stats:
if sample_stats['total-retained-reads'] > 0:
sample_id = str(sample_stats.name)
forward_output_fp, reverse_output_fp = _get_output_filepaths(
str(sample_stats.name), demux_manifest_df, result
sample_id, demux_manifest_df, Path(result.path)
)
manifest_fh.write(
f'{sample_id},{forward_output_fp.name},forward\n'
Expand All @@ -543,7 +563,8 @@ def q_score(

# combine per-sample filtering stats into dataframe
non_filtered_samples = [
stats for stats in sample_results if stats['total-retained-reads'] > 0
stats for stats in all_sample_stats
if stats['total-retained-reads'] > 0
]
filtering_stats_df = pd.DataFrame(non_filtered_samples)
filtering_stats_df.index.name = 'sample-id'
Expand Down
8 changes: 7 additions & 1 deletion q2_quality_filter/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
'min_quality': qiime2.plugin.Int,
'quality_window': qiime2.plugin.Int,
'min_length_fraction': qiime2.plugin.Float,
'max_ambiguous': qiime2.plugin.Int
'max_ambiguous': qiime2.plugin.Int,
'num_processes': qiime2.plugin.Threads,
}

_q_score_input_descriptions = {
Expand All @@ -80,6 +81,11 @@
'max_ambiguous': (
'The maximum number of ambiguous (i.e., N) base calls. This is '
'applied after trimming sequences based on `min_length_fraction`.'
),
'num_processes': (
'The number of processes to use. A higher number will improve '
'response time if the hardware resources are available. Request no '
'more than one process per available core.'
)
}

Expand Down

0 comments on commit 08f009a

Please sign in to comment.