diff --git a/neurodsp/filt/filter.py b/neurodsp/filt/filter.py index 1495c338..f5b9eebf 100644 --- a/neurodsp/filt/filter.py +++ b/neurodsp/filt/filter.py @@ -9,9 +9,9 @@ ################################################################################################### ################################################################################################### -def filter_signal(sig, fs, pass_type, f_range, filter_type='fir', - n_cycles=3, n_seconds=None, remove_edges=True, butterworth_order=None, - print_transitions=False, plot_properties=False, return_filter=False): +def filter_signal(sig, fs, pass_type, f_range, filter_type=None, + print_transitions=False, plot_properties=False, return_filter=False, + **filter_kwargs): """Apply a bandpass, bandstop, highpass, or lowpass filter to a neural signal. Parameters @@ -32,27 +32,31 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir', For 'bandpass' & 'bandstop', must be a tuple. For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'. - n_cycles : float, optional, default: 3 - Length of filter, in number of cycles, at the 'f_lo' frequency, if using an FIR filter. - This parameter is overwritten by `n_seconds`, if provided. - n_seconds : float, optional - Length of filter, in seconds, if using an FIR filter. - This parameter overwrites `n_cycles`. filter_type : {'fir', 'iir'}, optional - Whether to use an FIR or IIR filter. - The only IIR filter offered is a butterworth filter. - remove_edges : bool, optional, default: True - If True, replace samples within half the kernel length to be np.nan. - Only used for FIR filters. - butterworth_order : int, optional - Order of the butterworth filter, if using an IIR filter. - See input 'N' in scipy.signal.butter. + Whether to use an FIR or IIR filter. IIR option is a butterworth filter. + If None, type is inferred from input parameters, and/or defaults to FIR. print_transitions : bool, optional, default: True If True, print out the transition and pass bandwidths. plot_properties : bool, optional, default: False If True, plot the properties of the filter, including frequency response and/or kernel. return_filter : bool, optional, default: False If True, return the filter coefficients. + **filter_kwargs + Additional parameters for the filtering function, specific to filtering type. + + | For FIR filters, can include: + | n_cycles : float, optional + | Filter length, in number of cycles, defined at 'f_lo' frequency. + | Either `n_cycles` or `n_seconds` can be set for the filter length, but not both. + | If not provided, and `n_seconds` is also not defined, defaults to 3. + | n_seconds : float, optional + | Filter length, in seconds. + | Either `n_cycles` or `n_seconds` can be set for the filter length, but not both. + | remove_edges : bool, optional, default: True + | If True, replace samples within half the kernel length to be np.nan. + | For IIR filters, can include: + | butterworth_order : int, optional + | Order of the butterworth filter. See input 'N' in scipy.signal.butter. Returns ------- @@ -72,27 +76,36 @@ def filter_signal(sig, fs, pass_type, f_range, filter_type='fir', ... filter_type='fir', f_range=(1, 25)) """ - check_param_options(filter_type, 'filter_type', ['fir', 'iir']) + if filter_type is not None: + check_param_options(filter_type, 'filter_type', ['fir', 'iir']) + else: + # Infer IIR if relevant parameter set, otherwise, assume FIR + filter_type = 'iir' if 'butterworth_order' in filter_kwargs else 'fir' + + _filter_input_checks(filter_type, filter_kwargs) if filter_type.lower() == 'fir': - return filter_signal_fir(sig, fs, pass_type, f_range, n_cycles, n_seconds, - remove_edges, print_transitions, - plot_properties, return_filter) + return filter_signal_fir(sig, fs, pass_type, f_range, **filter_kwargs, + print_transitions=print_transitions, + plot_properties=plot_properties, + return_filter=return_filter) elif filter_type.lower() == 'iir': - _iir_checks(n_seconds, butterworth_order, remove_edges) - return filter_signal_iir(sig, fs, pass_type, f_range, butterworth_order, - print_transitions, plot_properties, - return_filter) - - -def _iir_checks(n_seconds, butterworth_order, remove_edges): - """Checks for using an IIR filter if called from the general filter function.""" - - # Check inputs for IIR filters - if n_seconds is not None: - raise ValueError('n_seconds should not be defined for an IIR filter.') - if butterworth_order is None: - raise ValueError('butterworth_order must be defined when using an IIR filter.') - if remove_edges: - warn('Edge artifacts are not removed when using an IIR filter.') + return filter_signal_iir(sig, fs, pass_type, f_range, **filter_kwargs, + print_transitions=print_transitions, + plot_properties=plot_properties, + return_filter=return_filter) + + +FILTER_INPUTS = { + 'fir' : ['n_cycles', 'n_seconds', 'remove_edges'], + 'iir' : ['butterworth_order'], +} + + +def _filter_input_checks(filter_type, filter_kwargs): + """Check inputs to `filter_signal` match filter type.""" + + for param in filter_kwargs.keys(): + assert param in FILTER_INPUTS[filter_type], \ + 'Parameter {} not expected for {} filter'.format(param, filter_type) diff --git a/neurodsp/filt/fir.py b/neurodsp/filt/fir.py index 0b5f2cb7..0ffe3c90 100644 --- a/neurodsp/filt/fir.py +++ b/neurodsp/filt/fir.py @@ -13,8 +13,9 @@ ################################################################################################### ################################################################################################### -def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, remove_edges=True, - print_transitions=False, plot_properties=False, return_filter=False): +def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=None, n_seconds=None, + remove_edges=True, print_transitions=False, plot_properties=False, + return_filter=False): """Apply an FIR filter to a signal. Parameters @@ -35,11 +36,13 @@ def filter_signal_fir(sig, fs, pass_type, f_range, n_cycles=3, n_seconds=None, r For 'bandpass' & 'bandstop', must be a tuple. For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'. - n_cycles : float, optional, default: 3 - Length of filter, in number of cycles, defined at the 'f_lo' frequency. - This parameter is overwritten by `n_seconds`, if provided. + n_cycles : float, optional + Filter length, in number of cycles, defined at the 'f_lo' frequency. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. + If not provided, and `n_seconds` is also not defined, defaults to 3. n_seconds : float, optional - Length of filter, in seconds. This parameter overwrites `n_cycles`. + Filter length, in seconds. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. remove_edges : bool, optional If True, replace samples within half the kernel length to be np.nan. print_transitions : bool, optional, default: False @@ -134,7 +137,7 @@ def apply_fir_filter(sig, filter_coefs): return convolve(sig, filter_coefs, mode='same') -def design_fir_filter(fs, pass_type, f_range, n_cycles=3, n_seconds=None): +def design_fir_filter(fs, pass_type, f_range, n_cycles=None, n_seconds=None): """Design an FIR filter. Parameters @@ -153,11 +156,13 @@ def design_fir_filter(fs, pass_type, f_range, n_cycles=3, n_seconds=None): For 'bandpass' & 'bandstop', must be a tuple. For 'lowpass' or 'highpass', can be a float that specifies pass frequency, or can be a tuple and is assumed to be (None, f_hi) for 'lowpass', and (f_lo, None) for 'highpass'. - n_cycles : float, optional, default: 3 - Length of filter, in number of cycles, defined at the 'f_lo' frequency. - This parameter is overwritten by `n_seconds`, if provided. - n_seconds : float or None, optional - Length of filter, in seconds. This parameter overwrites `n_cycles`. + n_cycles : float, optional + Filter length, in number of cycles, defined at the 'f_lo' frequency. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. + If not provided, and `n_seconds` is also not defined, defaults to 3. + n_seconds : float, optional + Filter length, in seconds. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. Returns ------- @@ -171,8 +176,12 @@ def design_fir_filter(fs, pass_type, f_range, n_cycles=3, n_seconds=None): >>> filter_coefs = design_fir_filter(fs=500, pass_type='bandpass', f_range=(1, 25)) """ - # Check filter definition f_lo, f_hi = check_filter_definition(pass_type, f_range) + + # Default to a filter length of `n_cycles` of 3, if nothing otherwise set + if n_cycles is None and n_seconds is None: + n_cycles = 3 + filt_len = compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles, n_seconds) if pass_type == 'bandpass': @@ -200,16 +209,23 @@ def compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles=None, n_seconds=No The lower frequency range of the filter, specifying the highpass frequency, if specified. f_hi : float or None The higher frequency range of the filter, specifying the lowpass frequency, if specified. - n_cycles : float or None, optional - Length of filter, in number of cycles, defined at the 'f_lo' frequency. - n_seconds : float or None, optional - Length of filter, in seconds. + n_cycles : float, optional + Filter length, in number of cycles, defined at the 'f_lo' frequency. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. + n_seconds : float, optional + Filter length, in seconds. + Either `n_cycles` or `n_seconds` can be defined to set the filter length, but not both. Returns ------- filt_len : int The length of the specified filter. + Raises + ------ + ValueError + If both `n_cycles` & `n_seconds` are defined, leading to a conflict in filter length. + Examples -------- Compute the length of bandpass (1 to 25 Hz) filter: @@ -217,6 +233,10 @@ def compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles=None, n_seconds=No >>> filt_len = compute_filter_length(fs=500, pass_type='bandpass', f_lo=1, f_hi=25, n_cycles=3) """ + # Check for a conflict in defining length from having both `n_cycles` & `n_seconds` + if n_cycles is not None and n_seconds is not None: + raise ValueError('Either `n_cycles` or `n_seconds` can be defined, but not both.') + # Compute filter length if specified in seconds if n_seconds is not None: filt_len = fs * n_seconds @@ -229,7 +249,7 @@ def compute_filter_length(fs, pass_type, f_lo, f_hi, n_cycles=None, n_seconds=No else: raise ValueError('Either `n_cycles` or `n_seconds` needs to be defined.') - # Typecast filter length to an integer, rounding up & force length to be odd + # Typecast filter length to an integer, rounding up & forcing length to be odd filt_len = int(np.ceil(filt_len)) if filt_len % 2 == 0: filt_len = filt_len + 1 diff --git a/neurodsp/tests/filt/test_checks.py b/neurodsp/tests/filt/test_checks.py index f039049d..3f9570a7 100644 --- a/neurodsp/tests/filt/test_checks.py +++ b/neurodsp/tests/filt/test_checks.py @@ -54,7 +54,7 @@ def test_check_filter_properties(): # Check failing filter - insufficient attenuation with warnings.catch_warnings(record=True) as warn: - filter_coefs = design_fir_filter(FS, 'bandstop', (8, 12)) + filter_coefs = design_fir_filter(FS, 'bandstop', (8, 12), n_cycles=3) passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12)) assert passes is False assert len(warn) == 1 @@ -62,7 +62,7 @@ def test_check_filter_properties(): # Check failing filter - transition bandwidth with warnings.catch_warnings(record=True) as warn: - filter_coefs = design_fir_filter(FS, 'bandpass', (20, 21)) + filter_coefs = design_fir_filter(FS, 'bandpass', (20, 21), n_cycles=3) passes = check_filter_properties(filter_coefs, 1, FS, 'bandpass', (8, 12)) assert passes is False assert len(warn) == 1 diff --git a/neurodsp/tests/filt/test_filter.py b/neurodsp/tests/filt/test_filter.py index ee390b6c..5f35f2fc 100644 --- a/neurodsp/tests/filt/test_filter.py +++ b/neurodsp/tests/filt/test_filter.py @@ -7,7 +7,7 @@ from neurodsp.tests.settings import FS, F_RANGE from neurodsp.filt.filter import * -from neurodsp.filt.filter import _iir_checks +from neurodsp.filt.filter import _filter_input_checks ################################################################################################### ################################################################################################### @@ -25,16 +25,24 @@ def test_filter_signal(tsig): outs = filter_signal(sigs, FS, 'bandpass', F_RANGE, remove_edges=False) assert np.diff(outs, axis=0).sum() == 0 -def test_iir_checks(): +def test_filter_input_checks(): - # Check catch for having n_seconds defined - with raises(ValueError): - _iir_checks(1, 3, None) + fir_inputs = {'n_cycles' : 5, 'remove_edges' : False} + _filter_input_checks('fir', fir_inputs) - # Check catch for not having butterworth_order defined - with raises(ValueError): - _iir_checks(None, None, None) + iir_inputs = {'butterworth_order' : 7} + _filter_input_checks('iir', iir_inputs) - # Check catch for having remove_edges defined - with warns(UserWarning): - _iir_checks(None, 3, True) + mixed_inputs = {'n_cycles' : 5, 'butterworth_order' : 7} + extra_inputs = {'n_cycles' : 5, 'nonsense_input' : True} + + with raises(AssertionError): + _filter_input_checks('fir', iir_inputs) + with raises(AssertionError): + _filter_input_checks('iir', fir_inputs) + with raises(AssertionError): + _filter_input_checks('fir', mixed_inputs) + with raises(AssertionError): + _filter_input_checks('fir', mixed_inputs) + with raises(AssertionError): + _filter_input_checks('fir', extra_inputs) diff --git a/neurodsp/tests/filt/test_fir.py b/neurodsp/tests/filt/test_fir.py index 9da9a83f..6114df31 100644 --- a/neurodsp/tests/filt/test_fir.py +++ b/neurodsp/tests/filt/test_fir.py @@ -44,7 +44,7 @@ def test_design_fir_filter(): 'lowpass' : (None, 5), 'highpass' : (5, None)} for pass_type, f_range in test_filts.items(): - filter_coefs = design_fir_filter(FS, pass_type, f_range) + filter_coefs = design_fir_filter(FS, pass_type, f_range, n_cycles=3) def test_apply_fir_filter(tsig): @@ -61,22 +61,23 @@ def test_compute_filter_length(): # n_seconds here is chosen to create expected odd filt_len, without needing rounding up n_seconds = 1.75 expected_filt_len = n_seconds * fs - filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, - n_cycles=None, n_seconds=n_seconds) + filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_seconds=n_seconds) assert filt_len == expected_filt_len # Check filt_len, if defined using n_cycles n_cycles = 5 expected_filt_len = int(np.ceil(fs * n_cycles / f_lo)) - filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, - n_cycles=n_cycles, n_seconds=None) + filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_cycles=n_cycles) assert filt_len == expected_filt_len # Check filt_len, if expected to be rounded up to be odd n_cycles = 4 expected_filt_len = int(np.ceil(fs * n_cycles / f_lo)) + 1 - filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, - n_cycles=n_cycles, n_seconds=None) + filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_cycles=n_cycles) assert filt_len == expected_filt_len with raises(ValueError): filt_len = compute_filter_length(fs, 'bandpass', f_lo, f_hi) + + # Test error with inconsistent inputs + with raises(ValueError): + compute_filter_length(fs, 'bandpass', f_lo, f_hi, n_cycles=3, n_seconds=2.0) diff --git a/neurodsp/tests/filt/test_utils.py b/neurodsp/tests/filt/test_utils.py index 44539d10..d25887e9 100644 --- a/neurodsp/tests/filt/test_utils.py +++ b/neurodsp/tests/filt/test_utils.py @@ -19,7 +19,7 @@ def test_infer_passtype(): def test_compute_frequency_response(): - filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12)) + filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12), n_cycles=3) f_db, db = compute_frequency_response(filter_coefs, 1, FS) with raises(ValueError): @@ -33,7 +33,7 @@ def test_compute_pass_band(): def test_compute_transition_band(): - filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12)) + filter_coefs = design_fir_filter(FS, 'bandpass', (8, 12), n_cycles=3) f_db, db = compute_frequency_response(filter_coefs, 1, FS) trans_band = compute_transition_band(f_db, db)