Skip to content

Commit

Permalink
example of reorganizing filter_signal arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Feb 23, 2024
1 parent 9675cd8 commit 8977cd9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 52 deletions.
75 changes: 38 additions & 37 deletions neurodsp/filt/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
###################################################################################################

def filter_signal(sig, fs, pass_type, f_range, filter_type=None,
n_cycles=None, n_seconds=None, remove_edges=True, butterworth_order=None,
print_transitions=False, plot_properties=False, return_filter=False):
print_transitions=False, plot_properties=False, return_filter=False,
**filter_kwargs):
"""Apply a bandpass, bandstop, highpass, or lowpass filter to a neural signal.
Parameters
Expand All @@ -32,28 +32,31 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type=None,
For 'bandpass' & 'bandstop', must be a tuple.
For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be
a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
n_cycles : float, optional
Filter length, in number of cycles, defined at 'f_lo' frequency, if using an FIR filter.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
If not provided, and `n_seconds` is also not defined, defaults to 3.
n_seconds : float, optional
Filter length, in seconds, if using an FIR filter.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
filter_type : {'fir', 'iir'}, optional
Whether to use an FIR or IIR filter. IIR option is a butterworth filter.
If None, type is inferred from input parameters, and/or defaults to FIR.
remove_edges : bool, optional, default: True
If True, replace samples within half the kernel length to be np.nan.
Only used for FIR filters.
butterworth_order : int, optional
Order of the butterworth filter, if using an IIR filter.
See input 'N' in scipy.signal.butter.
print_transitions : bool, optional, default: True
If True, print out the transition and pass bandwidths.
plot_properties : bool, optional, default: False
If True, plot the properties of the filter, including frequency response and/or kernel.
return_filter : bool, optional, default: False
If True, return the filter coefficients.
**filter_kwargs
For FIR filters:
n_cycles : float, optional
Filter length, in number of cycles, defined at 'f_lo' frequency, if using an FIR filter.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
If not provided, and `n_seconds` is also not defined, defaults to 3.
n_seconds : float, optional
Filter length, in seconds, if using an FIR filter.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
remove_edges : bool, optional, default: True
If True, replace samples within half the kernel length to be np.nan.
Only used for FIR filters.
For IIR filters:
butterworth_order : int, optional
Order of the butterworth filter, if using an IIR filter.
See input 'N' in scipy.signal.butter.
Returns
-------
Expand All @@ -77,34 +80,32 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type=None,
check_param_options(filter_type, 'filter_type', ['fir', 'iir'])
else:
# Infer IIR if relevant parameter set, otherwise, assume FIR
filter_type = 'iir' if butterworth_order is not None else 'fir'
filter_type = 'iir' if 'butterworth_order' in filter_kwargs else 'fir'

_filter_input_checks(filter_type, filter_kwargs)

if filter_type.lower() == 'fir':
_fir_checks(butterworth_order)
return filter_signal_fir(sig, fs, pass_type, f_range, n_cycles, n_seconds,
remove_edges, print_transitions,
plot_properties, return_filter)
return filter_signal_fir(sig, fs, pass_type, f_range, **filter_kwargs,
print_transitions=print_transitions,
plot_properties=plot_properties,
return_filter=return_filter)

elif filter_type.lower() == 'iir':
_iir_checks(n_seconds, butterworth_order, remove_edges)
return filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order,
print_transitions, plot_properties,
return_filter)

return filter_signal_iir(sig, fs, pass_type, f_range, **filter_kwargs,
print_transitions=print_transitions,
plot_properties=plot_properties,
return_filter=return_filter)

def _fir_checks(butterworth_order):
"""Check inputs for using an FIR filter if called from the general filter function."""

if butterworth_order is not None:
raise ValueError('`butterworth_order` should not be defined when using an FIR filter.')
FILTER_INPUTS = {
'fir' : ['n_cycles', 'n_seconds', 'remove_edges'],
'iir' : ['butterworth_order'],
}


def _iir_checks(n_seconds, butterworth_order, remove_edges):
"""Check inputs for using an IIR filter if called from the general filter function."""
def _filter_input_checks(filter_type, filter_kwargs):
"""Check inputs to `filter_signal` match filter type."""

if n_seconds is not None:
raise ValueError('`n_seconds` should not be defined for an IIR filter.')
if butterworth_order is None:
raise ValueError('`butterworth_order` must be defined when using an IIR filter.')
if remove_edges:
warn('Edge artifacts are not removed when using an IIR filter.')
for param in filter_kwargs.keys():
assert param in FILTER_INPUTS[filter_type], \
'Parameter {} not expected for {} filter'.format(param, filter_type)
15 changes: 0 additions & 15 deletions neurodsp/tests/filt/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from neurodsp.tests.settings import FS, F_RANGE

from neurodsp.filt.filter import *
from neurodsp.filt.filter import _iir_checks

###################################################################################################
###################################################################################################
Expand All @@ -24,17 +23,3 @@ def test_filter_signal(tsig):
sigs = np.array([tsig, tsig])
outs = filter_signal(sigs, FS, 'bandpass', F_RANGE, remove_edges=False)
assert np.diff(outs, axis=0).sum() == 0

def test_iir_checks():

# Check catch for having n_seconds defined
with raises(ValueError):
_iir_checks(1, 3, None)

# Check catch for not having butterworth_order defined
with raises(ValueError):
_iir_checks(None, None, None)

# Check catch for having remove_edges defined
with warns(UserWarning):
_iir_checks(None, 3, True)

0 comments on commit 8977cd9

Please sign in to comment.