Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Add multitaper method for spectral density estimation #317

Merged
merged 15 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ Spectral Power
compute_spectrum_welch
compute_spectrum_wavelet
compute_spectrum_medfilt
compute_spectrum_multitaper

Spectral Measures
~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 2 additions & 2 deletions neurodsp/spectral/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Spectral module, for calculating power spectra, spectral variance, etc."""

from .power import (compute_spectrum, compute_spectrum_welch,
compute_spectrum_wavelet, compute_spectrum_medfilt)
from .power import (compute_spectrum, compute_spectrum_welch, compute_spectrum_wavelet,
compute_spectrum_medfilt, compute_spectrum_multitaper)
from .measures import compute_absolute_power, compute_relative_power, compute_band_ratio
from .variance import compute_scv, compute_scv_rs, compute_spectral_hist
from .utils import trim_spectrum, trim_spectrogram
Expand Down
42 changes: 42 additions & 0 deletions neurodsp/spectral/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,45 @@ def check_spg_settings(fs, window, nperseg, noverlap):
noverlap = int(noverlap)

return nperseg, noverlap


def check_mt_settings(n_samples, fs, bandwidth, n_tapers):
"""Check settings used for computing spectra using the multitaper method.

Parameters
----------
n_samples : int
Number of samples in the signal.
fs : float
Sampling rate, in Hz.
bandwidth : float or None
Bandwidth of the multitaper window, in Hz. If None, will use
8 * fs / n_samples.
n_tapers : int or None
Number of tapers to use. If None, will use bandwidth * n_samples / fs

Returns
-------
nw : float
Standardized half bandwidth (used to compute DPSS)
n_tapers : int
Number of tapers.
"""

# set bandwidth
if bandwidth is None:
bandwidth = 8 * fs / n_samples # MNE default

# check bandwidth - break if alpha < 1
alpha = n_samples * bandwidth / (fs * 2)
if alpha < 1:
raise ValueError("Bandwidth too narrow for signal length and sampling rate. Try increasing bandwidth. n_samples * bandwidth / (fs * 2) must be >1")

# compute nw
nw = bandwidth * n_samples / (fs * 2)

# compute max number of DPSS tapers
if n_tapers is None:
n_tapers = int(2 * nw)

return nw, n_tapers
90 changes: 87 additions & 3 deletions neurodsp/spectral/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from neurodsp.utils.outliers import discard_outliers
from neurodsp.timefrequency.wavelets import compute_wavelet_transform
from neurodsp.spectral.utils import trim_spectrum
from neurodsp.spectral.checks import check_spg_settings
from neurodsp.spectral.checks import check_spg_settings, check_mt_settings

###################################################################################################
###################################################################################################
Expand All @@ -30,7 +30,7 @@ def compute_spectrum(sig, fs, method='welch', **kwargs):
Time series.
fs : float
Sampling rate, in Hz.
method : {'welch', 'wavelet', 'medfilt'}, optional
method : {'welch', 'wavelet', 'medfilt', 'multitaper'}, optional
Method to use to estimate the power spectrum.
**kwargs
Keyword arguments to pass through to the function that calculates the spectrum.
Expand All @@ -53,7 +53,7 @@ def compute_spectrum(sig, fs, method='welch', **kwargs):
>>> freqs, spectrum = compute_spectrum(sig, fs=500)
"""

check_param_options(method, 'method', ['welch', 'wavelet', 'medfilt'])
check_param_options(method, 'method', ['welch', 'wavelet', 'medfilt', 'multitaper'])
_spectrum_input_checks(method, kwargs)

if method == 'welch':
Expand All @@ -65,6 +65,9 @@ def compute_spectrum(sig, fs, method='welch', **kwargs):
elif method == 'medfilt':
return compute_spectrum_medfilt(sig, fs, **kwargs)

elif method == 'multitaper':
return compute_spectrum_multitaper(sig, fs, **kwargs)


SPECTRUM_INPUTS = {
'welch' : ['avg_type', 'window', 'nperseg', 'noverlap', 'f_range', 'outlier_percent'],
Expand Down Expand Up @@ -256,3 +259,84 @@ def compute_spectrum_medfilt(sig, fs, filt_len=1., f_range=None):
freqs, spectrum = trim_spectrum(freqs, spectrum, f_range)

return freqs, spectrum


def compute_spectrum_multitaper(sig, fs, bandwidth=None, n_tapers=None,
low_bias=True, eigenvalue_weighting=True):
"""Compute the power spectral density using the multi-taper method.

Parameters
----------
sig : 1d or 2d array
Time series.
fs : float
Sampling rate, in Hz.
bandwidth : float, optional
Frequency bandwidth of multi-taper window function. Default is
8 * fs / n_samples.
n_tapers : int, optional
Number of slepian windows used to compute the spectrum. Default is
bandwidth * n_samples / fs.
low_bias : bool, optional
If True, only use tapers with concentration ratio > 0.9. Default is
True.
eigenvalue_weighting : bool, optional
If True, weight spectral estimates by the concentration ratio of
their respective tapers before combining. Default is True.

Returns
-------
freqs : 1d array
Frequencies at which the measure was calculated.
spectrum : 1d or 2d array
Power spectral density using multi-taper method.

Examples
--------
Compute the power spectrum of a simulated time series using the
multitaper method:

>>> from neurodsp.sim import sim_combined
>>> sig = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}})
>>> freqs, spec = compute_spectrum_multitaper(sig, fs=500)
"""

from scipy.signal.windows import dpss

# Compute signal length based on input shape
sig_len = sig.shape[sig.ndim - 1]

# check settings
nw, n_tapers = check_mt_settings(sig_len, fs, bandwidth, n_tapers)

# Create slepian sequences
slepian_sequences, ratios = dpss(sig_len, nw, n_tapers,
return_ratios=True)

# Drop tapers with low concentration
if low_bias:
slepian_sequences = slepian_sequences[ratios > 0.9]
ratios = ratios[ratios > 0.9]
if len(slepian_sequences) == 0:
raise ValueError('No tapers with concentration ratio > 0.9. Could not compute spectrum with low_bias=True.')

# Compute fourier on signal weighted by each slepian sequence
freqs = np.fft.rfftfreq(sig_len, 1. /fs)
spectra = np.abs(np.fft.rfft(slepian_sequences[:, np.newaxis]*sig))**2

# combine estimates to compute final spectrum
if eigenvalue_weighting:
# weight estimates by concentration ratios and combine
spectra_weighted = spectra * ratios[:, np.newaxis, np.newaxis]
spectrum = np.sum(spectra_weighted, axis=0) / np.sum(ratios)

else:
# Average spectral estimates
spectrum = spectra.mean(axis=0)

# Convert output to 1d if necessary
if sig.ndim == 1:
spectrum = spectrum[0]

return freqs, spectrum
29 changes: 23 additions & 6 deletions neurodsp/tests/spectral/test_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,8 @@ def test_compute_spectrum(tsig):
freqs, spectrum = compute_spectrum(tsig, FS, method='medfilt')
assert freqs.shape == spectrum.shape


SPECTRUM_INPUTS = {
'welch' : ['avg_type', 'window', 'nperseg', 'noverlap', 'f_range', 'outlier_percent'],
'wavelet' : ['freqs', 'avg_type', 'n_cycles', 'scaling', 'norm'],
'medfilt' : ['filt_len', 'f_range'],
}
freqs, spectrum = compute_spectrum(tsig, FS, method='multitaper')
assert freqs.shape == spectrum.shape

def test_spectrum_input_checks():

Expand All @@ -55,6 +51,10 @@ def test_compute_spectrum_2d(tsig2d):
assert freqs.shape[-1] == spectrum.shape[-1]
assert spectrum.ndim == 2

freqs, spectrum = compute_spectrum(tsig2d, FS, method='multitaper')
assert freqs.shape[-1] == spectrum.shape[-1]
assert spectrum.ndim == 2

def test_compute_spectrum_welch(tsig, tsig_sine):

freqs, spectrum = compute_spectrum_welch(tsig, FS, avg_type='mean')
Expand Down Expand Up @@ -104,3 +104,20 @@ def test_compute_spectrum_medfilt(tsig, tsig_sine):
# Therefore, it should match the estimate of psd from above
_, psd_medfilt = compute_spectrum(tsig_sine, FS, method='medfilt', filt_len=0.1)
assert np.allclose(psd, psd_medfilt, atol=EPS)

def test_compute_spectrum_multitaper(tsig_sine, tsig2d):
# Shape test: 1D input
freqs, spectrum = compute_spectrum_multitaper(tsig_sine, FS)
assert freqs.shape == spectrum.shape

# Shape test: 2D input
freqs_2d, spectrum_2d = compute_spectrum_multitaper(tsig2d, FS)
assert spectrum_2d.ndim == 2
assert spectrum_2d.shape[0] == tsig2d.shape[0]
assert spectrum_2d.shape[1] == len(freqs_2d)

# Accuracy test: peak at sine frequency
idx_freq_sine = np.argmin(np.abs(freqs - FREQ_SINE))
idx_peak = np.argmax(spectrum)
assert idx_freq_sine == idx_peak

Loading