From 8977cd94c11a05306cf1100500544c76bee17f3d Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 22 Feb 2024 23:53:14 -0500 Subject: [PATCH] example of reorganizing filter_signal arguments --- neurodsp/filt/filter.py | 75 +++++++++++++++--------------- neurodsp/tests/filt/test_filter.py | 15 ------ 2 files changed, 38 insertions(+), 52 deletions(-) diff --git a/neurodsp/filt/filter.py b/neurodsp/filt/filter.py index 29fc788c..75016152 100644 --- a/neurodsp/filt/filter.py +++ b/neurodsp/filt/filter.py @@ -10,8 +10,8 @@ ################################################################################################### def filter_signal(sig, fs, pass_type, f_range, filter_type=None, - n_cycles=None, n_seconds=None, remove_edges=True, butterworth_order=None, - print_transitions=False, plot_properties=False, return_filter=False): + print_transitions=False, plot_properties=False, return_filter=False, + **filter_kwargs): """Apply a bandpass, bandstop, highpass, or lowpass filter to a neural signal. Parameters @@ -32,28 +32,31 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type=None, For 'bandpass' & 'bandstop', must be a tuple. For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'. - n_cycles : float, optional - Filter length, in number of cycles, defined at 'f_lo' frequency, if using an FIR filter. - Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. - If not provided, and `n_seconds` is also not defined, defaults to 3. - n_seconds : float, optional - Filter length, in seconds, if using an FIR filter. - Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. filter_type : {'fir', 'iir'}, optional Whether to use an FIR or IIR filter. IIR option is a butterworth filter. If None, type is inferred from input parameters, and/or defaults to FIR. - remove_edges : bool, optional, default: True - If True, replace samples within half the kernel length to be np.nan. - Only used for FIR filters. - butterworth_order : int, optional - Order of the butterworth filter, if using an IIR filter. - See input 'N' in scipy.signal.butter. print_transitions : bool, optional, default: True If True, print out the transition and pass bandwidths. plot_properties : bool, optional, default: False If True, plot the properties of the filter, including frequency response and/or kernel. return_filter : bool, optional, default: False If True, return the filter coefficients. + **filter_kwargs + For FIR filters: + n_cycles : float, optional + Filter length, in number of cycles, defined at 'f_lo' frequency, if using an FIR filter. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. + If not provided, and `n_seconds` is also not defined, defaults to 3. + n_seconds : float, optional + Filter length, in seconds, if using an FIR filter. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. + remove_edges : bool, optional, default: True + If True, replace samples within half the kernel length to be np.nan. + Only used for FIR filters. + For IIR filters: + butterworth_order : int, optional + Order of the butterworth filter, if using an IIR filter. + See input 'N' in scipy.signal.butter. Returns ------- @@ -77,34 +80,32 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type=None, check_param_options(filter_type, 'filter_type', ['fir', 'iir']) else: # Infer IIR if relevant parameter set, otherwise, assume FIR - filter_type = 'iir' if butterworth_order is not None else 'fir' + filter_type = 'iir' if 'butterworth_order' in filter_kwargs else 'fir' + + _filter_input_checks(filter_type, filter_kwargs) if filter_type.lower() == 'fir': - _fir_checks(butterworth_order) - return filter_signal_fir(sig, fs, pass_type, f_range, n_cycles, n_seconds, - remove_edges, print_transitions, - plot_properties, return_filter) + return filter_signal_fir(sig, fs, pass_type, f_range, **filter_kwargs, + print_transitions=print_transitions, + plot_properties=plot_properties, + return_filter=return_filter) elif filter_type.lower() == 'iir': - _iir_checks(n_seconds, butterworth_order, remove_edges) - return filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, - print_transitions, plot_properties, - return_filter) - + return filter_signal_iir(sig, fs, pass_type, f_range, **filter_kwargs, + print_transitions=print_transitions, + plot_properties=plot_properties, + return_filter=return_filter) -def _fir_checks(butterworth_order): - """Check inputs for using an FIR filter if called from the general filter function.""" - if butterworth_order is not None: - raise ValueError('`butterworth_order` should not be defined when using an FIR filter.') +FILTER_INPUTS = { + 'fir' : ['n_cycles', 'n_seconds', 'remove_edges'], + 'iir' : ['butterworth_order'], +} -def _iir_checks(n_seconds, butterworth_order, remove_edges): - """Check inputs for using an IIR filter if called from the general filter function.""" +def _filter_input_checks(filter_type, filter_kwargs): + """Check inputs to `filter_signal` match filter type.""" - if n_seconds is not None: - raise ValueError('`n_seconds` should not be defined for an IIR filter.') - if butterworth_order is None: - raise ValueError('`butterworth_order` must be defined when using an IIR filter.') - if remove_edges: - warn('Edge artifacts are not removed when using an IIR filter.') + for param in filter_kwargs.keys(): + assert param in FILTER_INPUTS[filter_type], \ + 'Parameter {} not expected for {} filter'.format(param, filter_type) diff --git a/neurodsp/tests/filt/test_filter.py b/neurodsp/tests/filt/test_filter.py index ee390b6c..cf9b83ca 100644 --- a/neurodsp/tests/filt/test_filter.py +++ b/neurodsp/tests/filt/test_filter.py @@ -7,7 +7,6 @@ from neurodsp.tests.settings import FS, F_RANGE from neurodsp.filt.filter import * -from neurodsp.filt.filter import _iir_checks ################################################################################################### ################################################################################################### @@ -24,17 +23,3 @@ def test_filter_signal(tsig): sigs = np.array([tsig, tsig]) outs = filter_signal(sigs, FS, 'bandpass', F_RANGE, remove_edges=False) assert np.diff(outs, axis=0).sum() == 0 - -def test_iir_checks(): - - # Check catch for having n_seconds defined - with raises(ValueError): - _iir_checks(1, 3, None) - - # Check catch for not having butterworth_order defined - with raises(ValueError): - _iir_checks(None, None, None) - - # Check catch for having remove_edges defined - with warns(UserWarning): - _iir_checks(None, 3, True)