From 08f009ac2e5caacebd44346f6763617e9d0d0c44 Mon Sep 17 00:00:00 2001 From: Colin Wood Date: Tue, 17 Dec 2024 09:51:15 -0700 Subject: [PATCH] parallelize q_score --- q2_quality_filter/_filter.py | 85 +++++++++++++++++++------------ q2_quality_filter/plugin_setup.py | 8 ++- 2 files changed, 60 insertions(+), 33 deletions(-) diff --git a/q2_quality_filter/_filter.py b/q2_quality_filter/_filter.py index 468825b..53167dc 100644 --- a/q2_quality_filter/_filter.py +++ b/q2_quality_filter/_filter.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from enum import Enum import gzip +import multiprocessing import os from pathlib import Path @@ -323,7 +324,7 @@ def _get_input_filepaths( def _get_output_filepaths( - sample_id: str, input_manifest: pd.DataFrame, format: _ReadDirectionTypes + sample_id: str, input_manifest: pd.DataFrame, format_path: Path ) -> tuple[Path, Path | None]: ''' Looks up and returns the forward fastq (and if it should exist) reverse @@ -337,8 +338,8 @@ def _get_output_filepaths( input_manifest : pd.DataFrame The single-end or paired-end input fastq format's manifest viewed as a `pd.DataFrame`. - format: _ReadDirectionTypes - A single-end or paired-end fastq directory format. + format_path: Path + The single-end or paired-end output format's path. Returns ------- @@ -349,9 +350,9 @@ def _get_output_filepaths( forward_input_fp, reverse_input_fp = _get_input_filepaths( sample_id, input_manifest ) - forward_output_fp = Path(format.path / Path(forward_input_fp).name) + forward_output_fp = format_path / Path(forward_input_fp).name if reverse_input_fp is not None: - reverse_output_fp = Path(format.path / Path(reverse_input_fp).name) + reverse_output_fp = format_path / Path(reverse_input_fp).name else: reverse_output_fp = None @@ -360,8 +361,8 @@ def _get_output_filepaths( def _process_sample( sample_id: str, - input_format: _ReadDirectionTypes, - output_format: _ReadDirectionTypes, + paired: bool, + output_format_path: Path, demux_manifest: pd.DataFrame, phred_offset: int, min_quality: int, @@ -380,14 +381,27 @@ def _process_sample( ---------- sample_id : str The sample id of the to-be-processed fastq file(s). - input_format : _ReadDirectionTypes - The input format containing single-end or paired-end fastq files. - output_format : _ReadDirectionTypes - The output format to which to write processed single-end or paired-end - fastq files. + paired : bool + Whether the input data is paired-end. + output_format_path : Path + The single-end or paired-end output format's path. demux_manifest : pd.DataFrame The input demux manifest containing a mapping from sample id to forward (and if present) reverse fastq filepaths. + phred_offset : int + The PHRED encoding of the record's quality scores. + min_quality : int + The minimum quality that a base must have in order to not be considered + part of a low quality window. + window_length : int + The length of the low quality window to search for. + min_length_fraction : float + The fraction of its original length a record must be greater than to + be retained. + max_ambiguous : int + The maximum number of ambiguous bases a record may contain to be + retained. + Returns ------- @@ -412,13 +426,12 @@ def _process_sample( sample_id, demux_manifest ) forward_output_fp, reverse_output_fp = _get_output_filepaths( - sample_id, demux_manifest, output_format + sample_id, demux_manifest, output_format_path ) # open output filehandle(s) and create fastq record iterator forward_fh = gzip.open(forward_output_fp, mode='wb') - paired = isinstance(input_format, SingleLanePerSamplePairedEndFastqDirFmt) if paired: reverse_fh = gzip.open(reverse_output_fp, mode='wb') @@ -486,7 +499,8 @@ def q_score( min_quality: int = 4, quality_window: int = 3, min_length_fraction: float = 0.75, - max_ambiguous: int = 0 + max_ambiguous: int = 0, + num_processes: int = 1, ) -> (_ReadDirectionUnion, pd.DataFrame): ''' Parameter defaults as used in Bokulich et al, Nature Methods 2013, same as @@ -512,26 +526,32 @@ def q_score( )['phred-offset'] demux_manifest_df = demux.manifest.view(pd.DataFrame) - sample_results = [] - for barcode_id, sample_id in enumerate(demux_manifest_df.index.values): - sample_stats = _process_sample( - sample_id=sample_id, - input_format=demux, - output_format=result, - demux_manifest=demux_manifest_df, - phred_offset=phred_offset, - min_quality=min_quality, - quality_window=quality_window, - min_length_fraction=min_length_fraction, - max_ambiguous=max_ambiguous, - ) + # create per-sample arguments for parallel invocations + sample_ids = demux_manifest_df.index + parameters = { + 'paired': paired, + 'output_format_path': Path(result.path), + 'demux_manifest': demux_manifest_df, + 'phred_offset': phred_offset, + 'min_quality': min_quality, + 'quality_window': quality_window, + 'min_length_fraction': min_length_fraction, + 'max_ambiguous': max_ambiguous, + } + per_sample_arguments = [ + [sample_id] + list(parameters.values()) for sample_id in sample_ids + ] - sample_results.append(sample_stats) + # schedule samples to processes + with multiprocessing.Pool(num_processes) as pool: + all_sample_stats = pool.starmap(_process_sample, per_sample_arguments) - # update fastq manifest if sample retained + # update fastq manifest for retained samples + for sample_stats in all_sample_stats: if sample_stats['total-retained-reads'] > 0: + sample_id = str(sample_stats.name) forward_output_fp, reverse_output_fp = _get_output_filepaths( - str(sample_stats.name), demux_manifest_df, result + sample_id, demux_manifest_df, Path(result.path) ) manifest_fh.write( f'{sample_id},{forward_output_fp.name},forward\n' @@ -543,7 +563,8 @@ def q_score( # combine per-sample filtering stats into dataframe non_filtered_samples = [ - stats for stats in sample_results if stats['total-retained-reads'] > 0 + stats for stats in all_sample_stats + if stats['total-retained-reads'] > 0 ] filtering_stats_df = pd.DataFrame(non_filtered_samples) filtering_stats_df.index.name = 'sample-id' diff --git a/q2_quality_filter/plugin_setup.py b/q2_quality_filter/plugin_setup.py index c7ce738..fd618c7 100644 --- a/q2_quality_filter/plugin_setup.py +++ b/q2_quality_filter/plugin_setup.py @@ -55,7 +55,8 @@ 'min_quality': qiime2.plugin.Int, 'quality_window': qiime2.plugin.Int, 'min_length_fraction': qiime2.plugin.Float, - 'max_ambiguous': qiime2.plugin.Int + 'max_ambiguous': qiime2.plugin.Int, + 'num_processes': qiime2.plugin.Threads, } _q_score_input_descriptions = { @@ -80,6 +81,11 @@ 'max_ambiguous': ( 'The maximum number of ambiguous (i.e., N) base calls. This is ' 'applied after trimming sequences based on `min_length_fraction`.' + ), + 'num_processes': ( + 'The number of processes to use. A higher number will improve ' + 'response time if the hardware resources are available. Request no ' + 'more than one process per available core.' ) }