Skip to content

Commit

Permalink
Merge pull request #322 from neurodsp-tools/filtopt
Browse files Browse the repository at this point in the history
[MNT] - Update filter type checks & inference
  • Loading branch information
TomDonoghue authored Feb 28, 2024
2 parents 4e8ef4d + 020aff2 commit a7b4f8d
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 77 deletions.
87 changes: 50 additions & 37 deletions neurodsp/filt/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
###################################################################################################
###################################################################################################

def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
n_cycles=3, n_seconds=None, remove_edges=True, butterworth_order=None,
print_transitions=False, plot_properties=False, return_filter=False):
def filter_signal(sig, fs, pass_type, f_range, filter_type=None,
print_transitions=False, plot_properties=False, return_filter=False,
**filter_kwargs):
"""Apply a bandpass, bandstop, highpass, or lowpass filter to a neural signal.
Parameters
Expand All @@ -32,27 +32,31 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
For 'bandpass' & 'bandstop', must be a tuple.
For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be
a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
n_cycles : float, optional, default: 3
Length of filter, in number of cycles, at the 'f_lo' frequency, if using an FIR filter.
This parameter is overwritten by `n_seconds`, if provided.
n_seconds : float, optional
Length of filter, in seconds, if using an FIR filter.
This parameter overwrites `n_cycles`.
filter_type : {'fir', 'iir'}, optional
Whether to use an FIR or IIR filter.
The only IIR filter offered is a butterworth filter.
remove_edges : bool, optional, default: True
If True, replace samples within half the kernel length to be np.nan.
Only used for FIR filters.
butterworth_order : int, optional
Order of the butterworth filter, if using an IIR filter.
See input 'N' in scipy.signal.butter.
Whether to use an FIR or IIR filter. IIR option is a butterworth filter.
If None, type is inferred from input parameters, and/or defaults to FIR.
print_transitions : bool, optional, default: True
If True, print out the transition and pass bandwidths.
plot_properties : bool, optional, default: False
If True, plot the properties of the filter, including frequency response and/or kernel.
return_filter : bool, optional, default: False
If True, return the filter coefficients.
**filter_kwargs
Additional parameters for the filtering function, specific to filtering type.
| For FIR filters, can include:
| n_cycles : float, optional
| Filter length, in number of cycles, defined at 'f_lo' frequency.
| Either `n_cycles` or `n_seconds` can be set for the filter length, but not both.
| If not provided, and `n_seconds` is also not defined, defaults to 3.
| n_seconds : float, optional
| Filter length, in seconds.
| Either `n_cycles` or `n_seconds` can be set for the filter length, but not both.
| remove_edges : bool, optional, default: True
| If True, replace samples within half the kernel length to be np.nan.
| For IIR filters, can include:
| butterworth_order : int, optional
| Order of the butterworth filter. See input 'N' in scipy.signal.butter.
Returns
-------
Expand All @@ -72,27 +76,36 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir',
... filter_type='fir', f_range=(1, 25))
"""

check_param_options(filter_type, 'filter_type', ['fir', 'iir'])
if filter_type is not None:
check_param_options(filter_type, 'filter_type', ['fir', 'iir'])
else:
# Infer IIR if relevant parameter set, otherwise, assume FIR
filter_type = 'iir' if 'butterworth_order' in filter_kwargs else 'fir'

_filter_input_checks(filter_type, filter_kwargs)

if filter_type.lower() == 'fir':
return filter_signal_fir(sig, fs, pass_type, f_range, n_cycles, n_seconds,
remove_edges, print_transitions,
plot_properties, return_filter)
return filter_signal_fir(sig, fs, pass_type, f_range, **filter_kwargs,
print_transitions=print_transitions,
plot_properties=plot_properties,
return_filter=return_filter)

elif filter_type.lower() == 'iir':
_iir_checks(n_seconds, butterworth_order, remove_edges)
return filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order,
print_transitions, plot_properties,
return_filter)


def _iir_checks(n_seconds, butterworth_order, remove_edges):
"""Checks for using an IIR filter if called from the general filter function."""

# Check inputs for IIR filters
if n_seconds is not None:
raise ValueError('n_seconds should not be defined for an IIR filter.')
if butterworth_order is None:
raise ValueError('butterworth_order must be defined when using an IIR filter.')
if remove_edges:
warn('Edge artifacts are not removed when using an IIR filter.')
return filter_signal_iir(sig, fs, pass_type, f_range, **filter_kwargs,
print_transitions=print_transitions,
plot_properties=plot_properties,
return_filter=return_filter)


FILTER_INPUTS = {
'fir' : ['n_cycles', 'n_seconds', 'remove_edges'],
'iir' : ['butterworth_order'],
}


def _filter_input_checks(filter_type, filter_kwargs):
"""Check inputs to `filter_signal` match filter type."""

for param in filter_kwargs.keys():
assert param in FILTER_INPUTS[filter_type], \
'Parameter {} not expected for {} filter'.format(param, filter_type)
56 changes: 38 additions & 18 deletions neurodsp/filt/fir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
###################################################################################################
###################################################################################################

def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, remove_edges=True,
print_transitions=False, plot_properties=False, return_filter=False):
def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=None, n_seconds=None,
remove_edges=True, print_transitions=False, plot_properties=False,
return_filter=False):
"""Apply an FIR filter to a signal.
Parameters
Expand All @@ -35,11 +36,13 @@ def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, r
For 'bandpass' & 'bandstop', must be a tuple.
For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be
a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
n_cycles : float, optional, default: 3
Length of filter, in number of cycles, defined at the 'f_lo' frequency.
This parameter is overwritten by `n_seconds`, if provided.
n_cycles : float, optional
Filter length, in number of cycles, defined at the 'f_lo' frequency.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
If not provided, and `n_seconds` is also not defined, defaults to 3.
n_seconds : float, optional
Length of filter, in seconds. This parameter overwrites `n_cycles`.
Filter length, in seconds.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
remove_edges : bool, optional
If True, replace samples within half the kernel length to be np.nan.
print_transitions : bool, optional, default: False
Expand Down Expand Up @@ -134,7 +137,7 @@ def apply_fir_filter(sig, filter_coefs):
return convolve(sig, filter_coefs, mode='same')


def design_fir_filter(fs, pass_type, f_range, n_cycles=3, n_seconds=None):
def design_fir_filter(fs, pass_type, f_range, n_cycles=None, n_seconds=None):
"""Design an FIR filter.
Parameters
Expand All @@ -153,11 +156,13 @@ def design_fir_filter(fs, pass_type, f_range, n_cycles=3, n_seconds=None):
For 'bandpass' & 'bandstop', must be a tuple.
For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be
a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'.
n_cycles : float, optional, default: 3
Length of filter, in number of cycles, defined at the 'f_lo' frequency.
This parameter is overwritten by `n_seconds`, if provided.
n_seconds : float or None, optional
Length of filter, in seconds. This parameter overwrites `n_cycles`.
n_cycles : float, optional
Filter length, in number of cycles, defined at the 'f_lo' frequency.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
If not provided, and `n_seconds` is also not defined, defaults to 3.
n_seconds : float, optional
Filter length, in seconds.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
Returns
-------
Expand All @@ -171,8 +176,12 @@ def design_fir_filter(fs, pass_type, f_range, n_cycles=3, n_seconds=None):
>>> filter_coefs = design_fir_filter(fs=500, pass_type='bandpass', f_range=(1, 25))
"""

# Check filter definition
f_lo, f_hi = check_filter_definition(pass_type, f_range)

# Default to a filter length of `n_cycles` of 3, if nothing otherwise set
if n_cycles is None and n_seconds is None:
n_cycles = 3

filt_len = compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles, n_seconds)

if pass_type == 'bandpass':
Expand Down Expand Up @@ -200,23 +209,34 @@ def compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles=None, n_seconds=No
The lower frequency range of the filter, specifying the highpass frequency, if specified.
f_hi : float or None
The higher frequency range of the filter, specifying the lowpass frequency, if specified.
n_cycles : float or None, optional
Length of filter, in number of cycles, defined at the 'f_lo' frequency.
n_seconds : float or None, optional
Length of filter, in seconds.
n_cycles : float, optional
Filter length, in number of cycles, defined at the 'f_lo' frequency.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
n_seconds : float, optional
Filter length, in seconds.
Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both.
Returns
-------
filt_len : int
The length of the specified filter.
Raises
------
ValueError
If both `n_cycles` & `n_seconds` are defined, leading to a conflict in filter length.
Examples
--------
Compute the length of bandpass (1 to 25 Hz) filter:
>>> filt_len = compute_filter_length(fs=500, pass_type='bandpass', f_lo=1, f_hi=25, n_cycles=3)
"""

# Check for a conflict in defining length from having both `n_cycles` & `n_seconds`
if n_cycles is not None and n_seconds is not None:
raise ValueError('Either `n_cycles` or `n_seconds` can be defined, but not both.')

# Compute filter length if specified in seconds
if n_seconds is not None:
filt_len = fs * n_seconds
Expand All @@ -229,7 +249,7 @@ def compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles=None, n_seconds=No
else:
raise ValueError('Either `n_cycles` or `n_seconds` needs to be defined.')

# Typecast filter length to an integer, rounding up & force length to be odd
# Typecast filter length to an integer, rounding up & forcing length to be odd
filt_len = int(np.ceil(filt_len))
if filt_len % 2 == 0:
filt_len = filt_len + 1
Expand Down
4 changes: 2 additions & 2 deletions neurodsp/tests/filt/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,15 @@ def test_check_filter_properties():

# Check failing filter - insufficient attenuation
with warnings.catch_warnings(record=True) as warn:
filter_coefs = design_fir_filter(FS, 'bandstop', (8, 12))
filter_coefs = design_fir_filter(FS, 'bandstop', (8, 12), n_cycles=3)
passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12))
assert passes is False
assert len(warn) == 1
assert "filter attenuation" in str(warn[-1].message)

# Check failing filter - transition bandwidth
with warnings.catch_warnings(record=True) as warn:
filter_coefs = design_fir_filter(FS, 'bandpass', (20, 21))
filter_coefs = design_fir_filter(FS, 'bandpass', (20, 21), n_cycles=3)
passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12))
assert passes is False
assert len(warn) == 1
Expand Down
30 changes: 19 additions & 11 deletions neurodsp/tests/filt/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from neurodsp.tests.settings import FS, F_RANGE

from neurodsp.filt.filter import *
from neurodsp.filt.filter import _iir_checks
from neurodsp.filt.filter import _filter_input_checks

###################################################################################################
###################################################################################################
Expand All @@ -25,16 +25,24 @@ def test_filter_signal(tsig):
outs = filter_signal(sigs, FS, 'bandpass', F_RANGE, remove_edges=False)
assert np.diff(outs, axis=0).sum() == 0

def test_iir_checks():
def test_filter_input_checks():

# Check catch for having n_seconds defined
with raises(ValueError):
_iir_checks(1, 3, None)
fir_inputs = {'n_cycles' : 5, 'remove_edges' : False}
_filter_input_checks('fir', fir_inputs)

# Check catch for not having butterworth_order defined
with raises(ValueError):
_iir_checks(None, None, None)
iir_inputs = {'butterworth_order' : 7}
_filter_input_checks('iir', iir_inputs)

# Check catch for having remove_edges defined
with warns(UserWarning):
_iir_checks(None, 3, True)
mixed_inputs = {'n_cycles' : 5, 'butterworth_order' : 7}
extra_inputs = {'n_cycles' : 5, 'nonsense_input' : True}

with raises(AssertionError):
_filter_input_checks('fir', iir_inputs)
with raises(AssertionError):
_filter_input_checks('iir', fir_inputs)
with raises(AssertionError):
_filter_input_checks('fir', mixed_inputs)
with raises(AssertionError):
_filter_input_checks('fir', mixed_inputs)
with raises(AssertionError):
_filter_input_checks('fir', extra_inputs)
15 changes: 8 additions & 7 deletions neurodsp/tests/filt/test_fir.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_design_fir_filter():
'lowpass' : (None, 5), 'highpass' : (5, None)}

for pass_type, f_range in test_filts.items():
filter_coefs = design_fir_filter(FS, pass_type, f_range)
filter_coefs = design_fir_filter(FS, pass_type, f_range, n_cycles=3)

def test_apply_fir_filter(tsig):

Expand All @@ -61,22 +61,23 @@ def test_compute_filter_length():
# n_seconds here is chosen to create expected odd filt_len, without needing rounding up
n_seconds = 1.75
expected_filt_len = n_seconds * fs
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi,
n_cycles=None, n_seconds=n_seconds)
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_seconds=n_seconds)
assert filt_len == expected_filt_len

# Check filt_len, if defined using n_cycles
n_cycles = 5
expected_filt_len = int(np.ceil(fs * n_cycles / f_lo))
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi,
n_cycles=n_cycles, n_seconds=None)
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_cycles=n_cycles)
assert filt_len == expected_filt_len

# Check filt_len, if expected to be rounded up to be odd
n_cycles = 4
expected_filt_len = int(np.ceil(fs * n_cycles / f_lo)) + 1
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi,
n_cycles=n_cycles, n_seconds=None)
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_cycles=n_cycles)
assert filt_len == expected_filt_len
with raises(ValueError):
filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi)

# Test error with inconsistent inputs
with raises(ValueError):
compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_cycles=3, n_seconds=2.0)
4 changes: 2 additions & 2 deletions neurodsp/tests/filt/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_infer_passtype():

def test_compute_frequency_response():

filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12))
filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12), n_cycles=3)
f_db, db = compute_frequency_response(filter_coefs, 1, FS)

with raises(ValueError):
Expand All @@ -33,7 +33,7 @@ def test_compute_pass_band():

def test_compute_transition_band():

filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12))
filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12), n_cycles=3)
f_db, db = compute_frequency_response(filter_coefs, 1, FS)
trans_band = compute_transition_band(f_db, db)

Expand Down

0 comments on commit a7b4f8d

Please sign in to comment.