diff --git a/q2_quality_filter/_filter.py b/q2_quality_filter/_filter.py index 9d5b79e..cbb6ced 100644 --- a/q2_quality_filter/_filter.py +++ b/q2_quality_filter/_filter.py @@ -9,6 +9,8 @@ from dataclasses import dataclass from enum import Enum import gzip +import functools +import multiprocessing import os from pathlib import Path @@ -225,8 +227,7 @@ def _process_record( def _is_retained( forward_status: RecordStatus, reverse_status: RecordStatus | None, - filtering_stats_df: pd.DataFrame, - sample_id: str + filtering_stats: pd.Series, ) -> bool: ''' Determines whether a fastq record or pair of fastq records will retained @@ -240,42 +241,34 @@ def _is_retained( reverse_status : RecordStatus or None The status of the record from the reverse fastq file if it exists otherwise None. - filtering_stats_df : pd.DataFrame - The data structure that tracks filtering stats. - sample_id : str - The sample id that the record(s) belongs to. + filtering_stats : pd.Series + The data structure that tracks filtering stats for `sample_id`. Returns ------- bool True if the record(s) is to be retained, False otherwise. 
''' - filtering_stats_df.loc[sample_id, 'total-input-reads'] += 1 + filtering_stats['total-input-reads'] += 1 if (RecordStatus.SHORT in (forward_status, reverse_status)): - filtering_stats_df.loc[sample_id, 'reads-truncated'] += 1 - filtering_stats_df.loc[ - sample_id, 'reads-too-short-after-truncation' - ] += 1 + filtering_stats['reads-truncated'] += 1 + filtering_stats['reads-too-short-after-truncation'] += 1 return False if (RecordStatus.AMBIGUOUS in (forward_status, reverse_status)): - filtering_stats_df.loc[ - sample_id, 'reads-exceeding-maximum-ambiguous-bases' - ] += 1 + filtering_stats['reads-exceeding-maximum-ambiguous-bases'] += 1 return False if (RecordStatus.TRUNCATED_AMBIGUOUS in (forward_status, reverse_status)): - filtering_stats_df.loc[sample_id, 'reads-truncated'] += 1 - filtering_stats_df.loc[ - sample_id, 'reads-exceeding-maximum-ambiguous-bases' - ] += 1 + filtering_stats['reads-truncated'] += 1 + filtering_stats['reads-exceeding-maximum-ambiguous-bases'] += 1 return False if (RecordStatus.TRUNCATED in (forward_status, reverse_status)): - filtering_stats_df.loc[sample_id, 'reads-truncated'] += 1 + filtering_stats['reads-truncated'] += 1 - filtering_stats_df.loc[sample_id, 'total-retained-reads'] += 1 + filtering_stats['total-retained-reads'] += 1 return True @@ -301,12 +294,214 @@ def _write_record(fastq_record: FastqRecord, fh: gzip.GzipFile) -> None: fh.write(fastq_record.quality_scores + b'\n') +def _get_input_filepaths( + sample_id: str, manifest: pd.DataFrame +) -> tuple[Path, Path | None]: + ''' + Looks up and returns the forward fastq (and if it exists) reverse fastq + absolute filepaths in the input format's `manifest` for `sample_id`. + + Parameters + ---------- + sample_id : str + The sample id of interest. + manifest : pd.DataFrame + A single-end or paired-end fastq format's manifest viewed as a + `pd.DataFrame`. 
+ + Returns + ------- + tuple[Path, Path | None] + A tuple containing the forward filepath and the reverse filepath if it + exists, otherwise None. + ''' + forward_fp = Path(manifest.loc[sample_id, 'forward']) + try: + reverse_fp = Path(manifest.loc[sample_id, 'reverse']) + except KeyError: + reverse_fp = None + + return forward_fp, reverse_fp + + +def _get_output_filepaths( + sample_id: str, input_manifest: pd.DataFrame, format_path: Path +) -> tuple[Path, Path | None]: + ''' + Looks up and returns the forward fastq (and if it should exist) reverse + fastq absolute filepaths that belong in the output format for `sample_id`. + Note that these filepaths may or may not already exist. + + Parameters + ---------- + sample_id : str + The sample id of interest. + input_manifest : pd.DataFrame + The single-end or paired-end input fastq format's manifest viewed as a + `pd.DataFrame`. + format_path: Path + The single-end or paired-end output format's path. + + Returns + ------- + tuple[Path, Path | None] + A tuple containing the forward filepath and the reverse filepath if it + should exist, otherwise None. + ''' + forward_input_fp, reverse_input_fp = _get_input_filepaths( + sample_id, input_manifest + ) + forward_output_fp = format_path / Path(forward_input_fp).name + if reverse_input_fp is not None: + reverse_output_fp = format_path / Path(reverse_input_fp).name + else: + reverse_output_fp = None + + return forward_output_fp, reverse_output_fp + + +def _process_sample( + sample_id: str, + paired: bool, + output_format_path: Path, + demux_manifest: pd.DataFrame, + phred_offset: int, + min_quality: int, + quality_window: int, + min_length_fraction: float, + max_ambiguous: int, +) -> pd.Series: + ''' + Processes the fastq records belonging to a single sample. Intended to be + the unit of parallelization. + + Reads records from `input_format`, processes them, and writes the processed + records to `output_format`. Collects and returns filtering statistics. 
+ +    Parameters + ---------- +    sample_id : str + The sample id of the to-be-processed fastq file(s). + paired : bool + Whether the input data is paired-end. + output_format_path : Path + The single-end or paired-end output format's path. + demux_manifest : pd.DataFrame + The input demux manifest containing a mapping from sample id to forward + (and if present) reverse fastq filepaths. + phred_offset : int + The PHRED encoding of the record's quality scores. + min_quality : int + The minimum quality that a base must have in order to not be considered + part of a low quality window. + quality_window : int + The length of the low quality window to search for. + min_length_fraction : float + The fraction of its original length a record must be greater than to + be retained. + max_ambiguous : int + The maximum number of ambiguous bases a record may contain to be + retained. + + + Returns + ------- + pd.Series + The processed sample's filtering statistics. + ''' + # initialize filtering stats + filtering_stats = pd.Series( + data=0, + name=sample_id, + index=[ + 'total-input-reads', + 'total-retained-reads', + 'reads-truncated', + 'reads-too-short-after-truncation', + 'reads-exceeding-maximum-ambiguous-bases' + ] + ) + + # get filepath(s) of input/output fastq file(s) + forward_input_fp, reverse_input_fp = _get_input_filepaths( + sample_id, demux_manifest + ) + forward_output_fp, reverse_output_fp = _get_output_filepaths( + sample_id, demux_manifest, output_format_path + ) + + # open output filehandle(s) and create fastq record iterator + forward_fh = gzip.open(forward_output_fp, mode='wb') + + if paired: + reverse_fh = gzip.open(reverse_output_fp, mode='wb') + + forward_iterator = _read_fastq_records(str(forward_input_fp)) + reverse_iterator = _read_fastq_records(str(reverse_input_fp)) + iterator = zip(forward_iterator, reverse_iterator) + else: + iterator = _read_fastq_records(str(forward_input_fp)) + + # process records + for fastq_record in iterator: + if paired: + 
forward_record, reverse_record = fastq_record + else: + forward_record = fastq_record + reverse_record = None + + forward_record, forward_status = _process_record( + fastq_record=forward_record, + phred_offset=phred_offset, + min_quality=min_quality, + window_length=quality_window + 1, + min_length_fraction=min_length_fraction, + max_ambiguous=max_ambiguous + ) + reverse_record, reverse_status = _process_record( + fastq_record=reverse_record, + phred_offset=phred_offset, + min_quality=min_quality, + window_length=quality_window + 1, + min_length_fraction=min_length_fraction, + max_ambiguous=max_ambiguous + ) + + # see if record(s) retained and update filtering stats + retained = _is_retained( + forward_status, reverse_status, filtering_stats + ) + + # if retained write to output file(s) + if retained: + if paired: + _write_record(forward_record, forward_fh) + _write_record(reverse_record, reverse_fh) + else: + _write_record(forward_record, forward_fh) + + # close output file(s) + forward_fh.close() + if paired: + reverse_fh.close() + + # delete output files if no records retained + if filtering_stats['total-retained-reads'] == 0: + os.remove(forward_output_fp) + if paired: + os.remove(reverse_output_fp) + + # return statistics + return filtering_stats + + def q_score( demux: _ReadDirectionTypes, min_quality: int = 4, quality_window: int = 3, min_length_fraction: float = 0.75, - max_ambiguous: int = 0 + max_ambiguous: int = 0, + num_processes: int = 1, ) -> (_ReadDirectionUnion, pd.DataFrame): ''' Parameter defaults as used in Bokulich et al, Nature Methods 2013, same as @@ -332,87 +527,33 @@ def q_score( )['phred-offset'] demux_manifest_df = demux.manifest.view(pd.DataFrame) - # initialize filtering stats tracking dataframe - filtering_stats_df = pd.DataFrame( - data=0, - index=demux_manifest_df.index, - columns=[ - 'total-input-reads', - 'total-retained-reads', - 'reads-truncated', - 'reads-too-short-after-truncation', - 'reads-exceeding-maximum-ambiguous-bases' - 
] + # create per-sample functions and sample_id arguments for parallel + # invocations + sample_ids = [(sample_id,) for sample_id in demux_manifest_df.index] + + _process_sample_partial = functools.partial( + _process_sample, + paired=paired, + output_format_path=Path(result.path), + demux_manifest=demux_manifest_df, + phred_offset=phred_offset, + min_quality=min_quality, + quality_window=quality_window, + min_length_fraction=min_length_fraction, + max_ambiguous=max_ambiguous ) - for barcode_id, sample_id in enumerate(demux_manifest_df.index.values): - # get/create filepath(s) of input/output fastq file(s) - forward_input_fp = demux_manifest_df.loc[sample_id, 'forward'] - forward_output_fp = Path(result.path) / Path(forward_input_fp).name - if paired: - reverse_input_fp = demux_manifest_df.loc[sample_id, 'reverse'] - reverse_output_fp = Path(result.path) / Path(reverse_input_fp).name - - # open output filehandle(s) and create fastq record iterator - forward_fh = gzip.open(forward_output_fp, mode='wb') - - if paired: - reverse_fh = gzip.open(reverse_output_fp, mode='wb') + # schedule samples to processes + with multiprocessing.Pool(num_processes) as pool: + all_sample_stats = pool.starmap(_process_sample_partial, sample_ids) - forward_iterator = _read_fastq_records(str(forward_input_fp)) - reverse_iterator = _read_fastq_records(str(reverse_input_fp)) - iterator = zip(forward_iterator, reverse_iterator) - else: - iterator = _read_fastq_records(str(forward_input_fp)) - - for fastq_record in iterator: - if paired: - forward_record, reverse_record = fastq_record - else: - forward_record = fastq_record - reverse_record = None - - # process records - forward_record, forward_status = _process_record( - fastq_record=forward_record, - phred_offset=phred_offset, - min_quality=min_quality, - window_length=quality_window + 1, - min_length_fraction=min_length_fraction, - max_ambiguous=max_ambiguous - ) - reverse_record, reverse_status = _process_record( - 
fastq_record=reverse_record, - phred_offset=phred_offset, - min_quality=min_quality, - window_length=quality_window + 1, - min_length_fraction=min_length_fraction, - max_ambiguous=max_ambiguous - ) - - # see if record(s) retained and update filtering stats - retained = _is_retained( - forward_status, - reverse_status, - filtering_stats_df, - sample_id + # update fastq manifest for retained samples + for sample_stats in all_sample_stats: + if sample_stats['total-retained-reads'] > 0: + sample_id = str(sample_stats.name) + forward_output_fp, reverse_output_fp = _get_output_filepaths( + sample_id, demux_manifest_df, Path(result.path) ) - - # if retained write to output file(s) - if retained: - if paired: - _write_record(forward_record, forward_fh) - _write_record(reverse_record, reverse_fh) - else: - _write_record(forward_record, forward_fh) - - # close output file(s) and update manifest if record(s) retained, - # otherwise delete the empty file(s) - forward_fh.close() - if paired: - reverse_fh.close() - - if filtering_stats_df.loc[sample_id, 'total-retained-reads'] > 0: manifest_fh.write( f'{sample_id},{forward_output_fp.name},forward\n' ) @@ -420,13 +561,17 @@ def q_score( manifest_fh.write( f'{sample_id},{reverse_output_fp.name},reverse\n' ) - else: - os.remove(forward_output_fp) - if paired: - os.remove(reverse_output_fp) + + # combine per-sample filtering stats into dataframe + non_filtered_samples = [ + stats for stats in all_sample_stats + if stats['total-retained-reads'] > 0 + ] + filtering_stats_df = pd.DataFrame(non_filtered_samples) + filtering_stats_df.index.name = 'sample-id' # error if all samples retained no reads - if filtering_stats_df['total-retained-reads'].sum() == 0: + if filtering_stats_df.empty: msg = ( 'All sequences from all samples were filtered. The parameter ' 'choices may have been too stringent for the data.' 
diff --git a/q2_quality_filter/plugin_setup.py b/q2_quality_filter/plugin_setup.py index 0ba9ac6..b697096 100644 --- a/q2_quality_filter/plugin_setup.py +++ b/q2_quality_filter/plugin_setup.py @@ -55,7 +55,8 @@ 'min_quality': qiime2.plugin.Int, 'quality_window': qiime2.plugin.Int, 'min_length_fraction': qiime2.plugin.Float, - 'max_ambiguous': qiime2.plugin.Int + 'max_ambiguous': qiime2.plugin.Int, + 'num_processes': qiime2.plugin.Threads, } _q_score_input_descriptions = { @@ -80,6 +81,11 @@ 'max_ambiguous': ( 'The maximum number of ambiguous (i.e., N) base calls. This is ' 'applied after trimming sequences based on `min_length_fraction`.' + ), + 'num_processes': ( + 'The number of processes to use. A higher number will improve ' + 'response time if the hardware resources are available. Request no ' + 'more than one process per available core.' ) } diff --git a/q2_quality_filter/tests/test_filter.py b/q2_quality_filter/tests/test_filter.py index 7ceffde..49ff5b1 100644 --- a/q2_quality_filter/tests/test_filter.py +++ b/q2_quality_filter/tests/test_filter.py @@ -255,10 +255,10 @@ def test_process_record(self): self.assertEqual(status, exp_status) def test_is_retained(self): - filtering_stats_df = pd.DataFrame( + filtering_stats = pd.Series( data=0, - index=['sample-a', 'sample-b', 'sample-c'], - columns=[ + name='sample-a', + index=[ 'total-input-reads', 'total-retained-reads', 'reads-truncated', @@ -271,83 +271,54 @@ def test_is_retained(self): retained = _is_retained( forward_status=RecordStatus.TRUNCATED, reverse_status=RecordStatus.UNTRUNCATED, - filtering_stats_df=filtering_stats_df, - sample_id='sample-a' + filtering_stats=filtering_stats, ) self.assertTrue(retained) - self.assertEqual( - filtering_stats_df.loc['sample-a', 'total-retained-reads'], 1 - ) - self.assertEqual( - filtering_stats_df.loc['sample-a', 'reads-truncated'], 1 - ) - filtering_stats_df.iloc[:, :] = 0 + self.assertEqual(filtering_stats['total-retained-reads'], 1) + 
self.assertEqual(filtering_stats['reads-truncated'], 1) + filtering_stats[:] = 0 # forward read only, retained retained = _is_retained( forward_status=RecordStatus.TRUNCATED, reverse_status=None, - filtering_stats_df=filtering_stats_df, - sample_id='sample-a' + filtering_stats=filtering_stats, ) self.assertTrue(retained) + self.assertEqual(filtering_stats['total-retained-reads'], 1) + self.assertEqual(filtering_stats['reads-truncated'], 1) self.assertEqual( - filtering_stats_df.loc['sample-a', 'total-retained-reads'], 1 - ) - self.assertEqual( - filtering_stats_df.loc['sample-a', 'reads-truncated'], 1 + filtering_stats['reads-too-short-after-truncation'], 0 ) - self.assertEqual( - filtering_stats_df.loc[ - 'sample-a', 'reads-too-short-after-truncation' - ], - 0 - ) - filtering_stats_df.iloc[:, :] = 0 + filtering_stats[:] = 0 # forward read only, short retained = _is_retained( forward_status=RecordStatus.SHORT, reverse_status=None, - filtering_stats_df=filtering_stats_df, - sample_id='sample-b' + filtering_stats=filtering_stats, ) self.assertFalse(retained) + self.assertEqual(filtering_stats['total-retained-reads'], 0) + self.assertEqual(filtering_stats['reads-truncated'], 1) self.assertEqual( - filtering_stats_df.loc['sample-b', 'total-retained-reads'], 0 + filtering_stats['reads-too-short-after-truncation'], 1 ) - self.assertEqual( - filtering_stats_df.loc['sample-b', 'reads-truncated'], 1 - ) - self.assertEqual( - filtering_stats_df.loc[ - 'sample-b', 'reads-too-short-after-truncation' - ], - 1 - ) - filtering_stats_df.iloc[:, :] = 0 + filtering_stats[:] = 0 # one read untruncated, one read truncated and ambiguous retained = _is_retained( forward_status=RecordStatus.UNTRUNCATED, reverse_status=RecordStatus.TRUNCATED_AMBIGUOUS, - filtering_stats_df=filtering_stats_df, - sample_id='sample-a' + filtering_stats=filtering_stats, ) self.assertFalse(retained) + self.assertEqual(filtering_stats['total-retained-reads'], 0) self.assertEqual( - 
filtering_stats_df.loc['sample-a', 'total-retained-reads'], 0 - ) - self.assertEqual( - filtering_stats_df.loc[ - 'sample-a', 'reads-exceeding-maximum-ambiguous-bases' - ], - 1 - ) - self.assertEqual( - filtering_stats_df.loc['sample-a', 'reads-truncated'], 1 + filtering_stats['reads-exceeding-maximum-ambiguous-bases'], 1 ) - filtering_stats_df.iloc[:, :] = 0 + self.assertEqual(filtering_stats['reads-truncated'], 1) + filtering_stats[:] = 0 def test_write_record(self): fastq_record = FastqRecord( @@ -404,58 +375,73 @@ def test_q_score_numeric_ids(self): self.assertEqual(obs_sids, exp_sids) self.assertEqual(set(stats.index), exp_sids) - def test_q_score(self): + def test_q_score_different_num_processes(self): ar = Artifact.load(self.get_data_path('simple.qza')) - with redirected_stdio(stdout=os.devnull): - obs_drop_ambig_ar, stats_ar = self.plugin.methods['q_score']( - ar, quality_window=2, min_quality=20, min_length_fraction=0.25) - obs_drop_ambig = obs_drop_ambig_ar.view( - SingleLanePerSampleSingleEndFastqDirFmt) - stats = stats_ar.view(pd.DataFrame) - exp_drop_ambig = ["@foo_1", - "ATGCATGC", - "+", - "DDDDBBDD"] - columns = ['sample-id', 'total-input-reads', 'total-retained-reads', - 'reads-truncated', - 'reads-too-short-after-truncation', - 'reads-exceeding-maximum-ambiguous-bases'] - exp_drop_ambig_stats = pd.DataFrame([('foo', 2, 1, 0, 0, 1), - ('bar', 1, 0, 0, 0, 1)], - columns=columns) - exp_drop_ambig_stats = exp_drop_ambig_stats.set_index('sample-id') - obs = [] - iterator = obs_drop_ambig.sequences.iter_views(FastqGzFormat) - for sample_id, fp in iterator: - obs.extend([x.strip() for x in gzip.open(str(fp), 'rt')]) - self.assertEqual(obs, exp_drop_ambig) - pdt.assert_frame_equal(stats, exp_drop_ambig_stats.loc[stats.index]) - - with redirected_stdio(stdout=os.devnull): - obs_trunc_ar, stats_ar = self.plugin.methods['q_score']( - ar, quality_window=1, min_quality=33, min_length_fraction=0.25) - obs_trunc = 
obs_trunc_ar.view(SingleLanePerSampleSingleEndFastqDirFmt) - stats = stats_ar.view(pd.DataFrame) + for num_processes in (1, 2): + with redirected_stdio(stdout=os.devnull): + obs_drop_ambig_ar, stats_ar = self.plugin.methods['q_score']( + ar, + quality_window=2, + min_quality=20, + min_length_fraction=0.25, + num_processes=num_processes, + ) + obs_drop_ambig = obs_drop_ambig_ar.view( + SingleLanePerSampleSingleEndFastqDirFmt) + stats = stats_ar.view(pd.DataFrame) - exp_trunc = ["@foo_1", - "ATGCATGC", - "+", - "DDDDBBDD", - "@bar_1", - "ATA", - "+", - "DDD"] - exp_trunc_stats = pd.DataFrame([('foo', 2, 1, 0, 0, 1), - ('bar', 1, 1, 1, 0, 0)], - columns=columns) - exp_trunc_stats = exp_trunc_stats.set_index('sample-id') + exp_drop_ambig = ["@foo_1", "ATGCATGC", "+", "DDDDBBDD"] + columns = [ + 'sample-id', + 'total-input-reads', + 'total-retained-reads', + 'reads-truncated', + 'reads-too-short-after-truncation', + 'reads-exceeding-maximum-ambiguous-bases' + ] + exp_drop_ambig_stats = pd.DataFrame( + [('foo', 2, 1, 0, 0, 1), ('bar', 1, 0, 0, 0, 1)], + columns=columns + ) + exp_drop_ambig_stats = exp_drop_ambig_stats.set_index('sample-id') + obs = [] + iterator = obs_drop_ambig.sequences.iter_views(FastqGzFormat) + for sample_id, fp in iterator: + obs.extend([x.strip() for x in gzip.open(str(fp), 'rt')]) + self.assertEqual(obs, exp_drop_ambig) + pdt.assert_frame_equal( + stats, exp_drop_ambig_stats.loc[stats.index] + ) - obs = [] - for sample_id, fp in obs_trunc.sequences.iter_views(FastqGzFormat): - obs.extend([x.strip() for x in gzip.open(str(fp), 'rt')]) - self.assertEqual(sorted(obs), sorted(exp_trunc)) - pdt.assert_frame_equal(stats, exp_trunc_stats.loc[stats.index]) + with redirected_stdio(stdout=os.devnull): + obs_trunc_ar, stats_ar = self.plugin.methods['q_score']( + ar, + quality_window=1, + min_quality=33, + min_length_fraction=0.25, + num_processes=num_processes, + ) + obs_trunc = obs_trunc_ar.view( + SingleLanePerSampleSingleEndFastqDirFmt + ) + stats = 
stats_ar.view(pd.DataFrame) + + exp_trunc = [ + "@foo_1", "ATGCATGC", "+", "DDDDBBDD", + "@bar_1", "ATA", "+", "DDD" + ] + exp_trunc_stats = pd.DataFrame( + [('foo', 2, 1, 0, 0, 1), ('bar', 1, 1, 1, 0, 0)], + columns=columns + ) + exp_trunc_stats = exp_trunc_stats.set_index('sample-id') + + obs = [] + for sample_id, fp in obs_trunc.sequences.iter_views(FastqGzFormat): + obs.extend([x.strip() for x in gzip.open(str(fp), 'rt')]) + self.assertEqual(sorted(obs), sorted(exp_trunc)) + pdt.assert_frame_equal(stats, exp_trunc_stats.loc[stats.index]) def test_q_score_real(self): self.maxDiff = None