[ENH] - Update rhythm code & docs #230

Merged · 16 commits · Jun 3, 2021
4 changes: 3 additions & 1 deletion neurodsp/plts/rhythm.py
@@ -24,14 +24,16 @@ def plot_swm_pattern(pattern, ax=None, **kwargs):
--------
Plot the average pattern from a sliding window matching analysis:

>>> import numpy as np
>>> from neurodsp.sim import sim_combined
>>> from neurodsp.rhythm import sliding_window_matching
>>> sig = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'f_range': (2, None)},
... 'sim_bursty_oscillation': {'freq': 20,
... 'enter_burst': .25,
... 'leave_burst': .25}})
>>> avg_window, _, _ = sliding_window_matching(sig, fs=500, win_len=0.05, win_spacing=0.5)
>>> windows, _ = sliding_window_matching(sig, fs=500, win_len=0.05, win_spacing=0.5)
>>> avg_window = np.mean(windows, axis=0)
>>> plot_swm_pattern(avg_window)
"""

32 changes: 22 additions & 10 deletions neurodsp/rhythm/lc.py
@@ -23,13 +23,14 @@ def compute_lagged_coherence(sig, fs, freqs, n_cycles=3, return_spectrum=False):
fs : float
Sampling rate, in Hz.
freqs : 1d array or list of float
If array, frequency values to estimate with morlet wavelets.
The frequency values at which to estimate lagged coherence.
If array, defines the frequency values to use.
If list, defines the frequency range, as [freq_start, freq_stop, freq_step].
The `freq_step` is optional, and defaults to 1. The range is inclusive of the `freq_stop` value.
n_cycles : float or list of float, default: 3
Number of cycles of each frequency to use to compute lagged coherence.
If a single value, the same number of cycles is used for each frequency value.
If a list or list_like, then should be a n_cycles corresponding to each frequency.
The number of cycles to use to compute lagged coherence, for each frequency.
If a single value, the same number of cycles is used for each frequency.
If a list or list_like, there should be a value corresponding to each frequency.
return_spectrum : bool, optional, default: False
If True, return the lagged coherence for all frequency values.
Otherwise, only the mean lagged coherence value across the frequency range is returned.
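
As an aside on the `freqs` parameter: the list form defines an inclusive range. A minimal sketch of the expansion logic (this `expand_freqs` helper is illustrative, not the library's internal implementation):

>>> import numpy as np
>>> def expand_freqs(freqs):
...     # Expand a [start, stop(, step)] list into an array, inclusive of the stop value;
...     # np.arange excludes the endpoint, so extend it by one step
...     if isinstance(freqs, list):
...         step = freqs[2] if len(freqs) == 3 else 1
...         return np.arange(freqs[0], freqs[1] + step, step)
...     return np.asarray(freqs)
>>> expand_freqs([5, 10])
array([ 5,  6,  7,  8,  9, 10])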
@@ -87,7 +88,7 @@ def compute_lagged_coherence(sig, fs, freqs, n_cycles=3, return_spectrum=False):


def lagged_coherence_1freq(sig, fs, freq, n_cycles):
"""Compute the lagged coherence of a frequency using the hanning-taper FFT method.
"""Compute the lagged coherence at a particular frequency.

Parameters
----------
@@ -98,12 +99,17 @@ def lagged_coherence_1freq(sig, fs, freq, n_cycles):
freq : float
The frequency at which to estimate lagged coherence.
n_cycles : float
Number of cycles at the examined frequency to use to compute lagged coherence.
The number of cycles of the given frequency to use to compute lagged coherence.

Returns
-------
float
lc : float
The computed lagged coherence value.

Notes
-----
- Lagged coherence is computed using hanning-tapered FFTs.
- The returned lagged coherence value is bound between 0 and 1.
"""

# Determine number of samples to be used in each window to compute lagged coherence
@@ -113,20 +119,26 @@
chunks = split_signal(sig, n_samps)
n_chunks = len(chunks)

# For each chunk, calculate the Fourier coefficients at the frequency of interest
# Create the window to apply to each chunk
hann_window = hann(n_samps)

# Create the frequency vector, finding the frequency value of interest
fft_freqs = np.fft.fftfreq(n_samps, 1 / float(fs))
fft_freqs_idx = np.argmin(np.abs(fft_freqs - freq))

# Calculate the Fourier coefficients across chunks for the frequency of interest
fft_coefs = np.zeros(n_chunks, dtype=complex)
for ind, chunk in enumerate(chunks):
fourier_coef = np.fft.fft(chunk * hann_window)
fft_coefs[ind] = fourier_coef[fft_freqs_idx]

# Compute the lagged coherence value
# Compute lagged coherence across data segments
lcs_num = 0
for ind in range(n_chunks - 1):
lcs_num += fft_coefs[ind] * np.conj(fft_coefs[ind + 1])
lcs_denom = np.sqrt(np.sum(np.abs(fft_coefs[:-1])**2) * np.sum(np.abs(fft_coefs[1:])**2))

return np.abs(lcs_num / lcs_denom)
# Normalize the lagged coherence value
lc_val = np.abs(lcs_num / lcs_denom)

return lc_val
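
In equation form, the computation above is: given Fourier coefficients $F_t(f)$ for the $n$ consecutive chunks, the lagged coherence at frequency $f$ is

$$\lambda(f) = \frac{\left| \sum_{t=1}^{n-1} F_t(f) \, F_{t+1}^{*}(f) \right|}{\sqrt{\sum_{t=1}^{n-1} \left| F_t(f) \right|^{2} \; \sum_{t=2}^{n} \left| F_t(f) \right|^{2}}}$$

i.e., the magnitude of the lag-1 autocorrelation of the windowed Fourier coefficients, which is bound between 0 and 1.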
208 changes: 117 additions & 91 deletions neurodsp/rhythm/swm.py
@@ -1,4 +1,4 @@
"""The sliding window matching algorithm for identifying rhythmic components of a neural signal."""
"""The sliding window matching algorithm for identifying recurring patterns in a neural signal."""

import numpy as np

@@ -8,8 +8,8 @@
###################################################################################################

@multidim()
def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=500,
temperature=1, window_starts_custom=None):
def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=100,
window_starts_custom=None, var_thresh=None):
"""Find recurring patterns in a time series using the sliding window matching algorithm.

Parameters
@@ -19,150 +19,176 @@ def sliding_window_matching(sig, fs, win_len, win_spacing, max_iterations=500,
fs : float
Sampling rate, in Hz.
win_len : float
Window length, in seconds.
Window length, in seconds. This is L in the original paper.
win_spacing : float
Minimum window spacing, in seconds.
max_iterations : int
Minimum spacing between windows, in seconds. This is G in the original paper.
max_iterations : int, optional, default: 100
Maximum number of iterations of potential changes in window placement.
temperature : float
Temperature parameter. Controls probability of accepting a new window.
window_starts_custom : 1d array, optional
window_starts_custom : 1d array, optional, default: None
Custom pre-set locations of initial windows.
var_thresh : float, optional, default: None
If provided, remove initial windows with variance below this threshold. This speeds up
runtime in proportion to the number of low-variance windows in the data.

Returns
-------
avg_window : 1d array
The average waveform in 'sig' in the frequency 'f_range' triggered on 'trigger'.
windows : 2d array
Putative patterns discovered in the input signal.
window_starts : 1d array
Indices at which each window begins for the final set of windows.
costs : 1d array
Cost function value at each iteration.

References
----------
.. [1] Gips, B., Bahramisharif, A., Lowet, E., Roberts, M. J., de Weerd, P., Jensen, O., &
van der Eerden, J. (2017). Discovering recurring patterns in electrophysiological
recordings. Journal of Neuroscience Methods, 275, 66-79.
DOI: 10.1016/j.jneumeth.2016.11.001
Matlab Code: https://github.com/bartgips/SWM
.. [2] Matlab Code implementation: https://github.com/bartgips/SWM

Notes
-----
- Apply a highpass filter if looking at high frequency activity, so that it does
not converge on a low frequency motif.
- Parameters `win_len` and `win_spacing` should be chosen to be about the size of the
motif of interest, and the N derived should be about the number of occurrences.
- The `win_len` parameter should be chosen to be about the size of the motif of interest.
The larger this window size, the more likely the discovered pattern is to reflect slower activity.
- The `win_spacing` parameter also determines the number of windows that are used.
- If looking at high frequency activity, you may want to apply a highpass filter,
so that the algorithm does not converge on a low frequency motif.
- This implementation is a minimal, modified version of the original
implementation in [2], which has more available options.
- This version has the following changes to speed up convergence:

1. Each iteration is similar to an epoch, randomly moving all windows in
random order. The original implementation randomly selects windows and
does not guarantee even resampling.
2. New window acceptance is determined via increased correlation coefficients
and reduced variance across windows.
3. Phase optimization / realignment to escape local minima.


Examples
--------
Search for recurring patterns using sliding window matching in a simulated beta signal:

>>> from neurodsp.sim import sim_combined
>>> sig = sim_combined(n_seconds=10, fs=500,
... components={'sim_powerlaw': {'f_range': (2, None)},
... 'sim_bursty_oscillation': {'freq': 20,
... 'enter_burst': .25,
... 'leave_burst': .25}})
>>> avg_window, window_starts, costs = sliding_window_matching(sig, fs=500, win_len=0.05,
... win_spacing=0.20)
>>> components = {'sim_bursty_oscillation': {'freq': 20, 'phase': 'min'},
... 'sim_powerlaw': {'f_range': (2, None)}}
>>> sig = sim_combined(10, fs=500, components=components, component_variances=(1, .05))
>>> windows, starts = sliding_window_matching(sig, fs=500, win_len=0.05, win_spacing=0.05,
... max_iterations=100, var_thresh=.5)
"""

# Compute window length and spacing in samples
win_n_samps = int(win_len * fs)
spacing_n_samps = int(win_spacing * fs)
win_len = int(win_len * fs)
win_spacing = int(win_spacing * fs)

# Initialize window positions
if window_starts_custom is None:
window_starts = np.arange(0, len(sig) - win_n_samps, 2 * spacing_n_samps)
window_starts = np.arange(0, len(sig) - win_len, win_spacing).astype(int)
else:
window_starts = window_starts_custom
n_windows = len(window_starts)

# Randomly sample windows with replacement
random_window_idx = np.random.choice(range(n_windows), size=max_iterations)
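# Extract the initial windows from the signal, as a 2d array of shape (n_windows, win_len)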
windows = np.array([sig[start:start+win_len] for start in window_starts])

# Calculate initial cost
costs = np.zeros(max_iterations)
costs[0] = _compute_cost(sig, window_starts, win_n_samps)
# Compute the allowed placement bounds for each window
lower_bounds, upper_bounds = _compute_bounds(window_starts, win_spacing, 0, len(sig) - win_len)

for iter_num in range(1, max_iterations):
# Remove low variance windows to speed up runtime
if var_thresh is not None:

# Pick a random window position to randomly replace with a
# new window to improve cross-window similarity
window_idx_replace = random_window_idx[iter_num]
thresh = np.array([np.var(sig[loc:loc+win_len]) > var_thresh for loc in window_starts])

# Find a new allowed position for the window
window_starts_temp = np.copy(window_starts)
window_starts_temp[window_idx_replace] = _find_new_window_idx(
window_starts, spacing_n_samps, len(sig) - win_n_samps)
windows = windows[thresh]
window_starts = window_starts[thresh]
lower_bounds = lower_bounds[thresh]
upper_bounds = upper_bounds[thresh]

# Calculate the cost & the change in the cost function
cost_temp = _compute_cost(sig, window_starts_temp, win_n_samps)
delta_cost = cost_temp - costs[iter_num - 1]
# Modified SWM procedure
window_idxs = np.arange(len(windows)).astype(int)

# Calculate the acceptance probability
p_accept = np.exp(-delta_cost / float(temperature))
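# Compute the initial cost terms: the window correlation matrix, the total
# variance across windows, and the mean absolute error from the average window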
corrs, variance = _compute_cost(sig, window_starts, win_len)
mae = np.mean(np.abs(windows - windows.mean(axis=0)))

# Accept update to J with a certain probability
if np.random.rand() < p_accept:
for _ in range(max_iterations):

# Update costs & windows
costs[iter_num] = cost_temp
window_starts = window_starts_temp
# Randomly shuffle order of windows
np.random.shuffle(window_idxs)

else:
for win_idx in window_idxs:

# Update costs
costs[iter_num] = costs[iter_num - 1]
# Find a new, random window start
_window_starts = window_starts.copy()
_window_starts[win_idx] = np.random.choice(np.arange(lower_bounds[win_idx],
upper_bounds[win_idx]+1))

# Calculate average window
avg_window = np.zeros(win_n_samps)
for w_ind in range(n_windows):
avg_window = avg_window + sig[window_starts[w_ind]:window_starts[w_ind] + win_n_samps]
avg_window = avg_window / float(n_windows)
# Accept new window if correlation increases and variance decreases
_corrs, _variance = _compute_cost(sig, _window_starts, win_len)

return avg_window, window_starts, costs
if _corrs[win_idx].sum() > corrs[win_idx].sum() and _variance < variance:

corrs = _corrs.copy()
variance = _variance
window_starts = _window_starts.copy()
lower_bounds, upper_bounds = _compute_bounds(window_starts, win_spacing, 0, len(sig) - win_len)

def _compute_cost(sig, window_starts, win_n_samps):
"""Compute the cost, which is proportional to the difference between pairs of windows."""
# Phase optimization
_window_starts = window_starts.copy()

# Get all windows and z-score them
n_windows = len(window_starts)
windows = np.zeros((n_windows, win_n_samps))
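# Jointly shift all windows by each offset within +/- half a window length,
# keeping any shift that reduces the mean absolute error from the average window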
for shift in np.arange(-int(win_len/2), int(win_len/2)):

for ind, window in enumerate(window_starts):
temp = sig[window:window_starts[ind] + win_n_samps]
windows[ind] = (temp - np.mean(temp)) / np.std(temp)
_starts = _window_starts + shift

# Calculate distances for all pairs of windows
dists = []
for ind1 in range(n_windows):
for ind2 in range(ind1 + 1, n_windows):
window_diff = windows[ind1] - windows[ind2]
dist_temp = np.sum(window_diff**2) / float(win_n_samps)
dists.append(dist_temp)
# Skip window shifts that are out-of-bounds
if (_starts[0] < 0) or (_starts[-1] > len(sig) - win_len):
continue

# Calculate cost function, which is the average difference, roughly
cost = np.sum(dists) / float(2 * (n_windows - 1))
_windows = np.array([sig[start:start+win_len] for start in _starts])

return cost
_mae = np.mean(np.abs(_windows - _windows.mean(axis=0)))

if _mae < mae:
window_starts = _starts.copy()
windows = _windows.copy()
mae = _mae

def _find_new_window_idx(window_starts, spacing_n_samps, n_samp, tries_limit=1000):
"""Find a new sample for the starting window."""
lower_bounds, upper_bounds = _compute_bounds(window_starts, win_spacing, 0, len(sig) - win_len)

for n_try in range(tries_limit):
return windows, window_starts

# Generate a random sample & check how close it is to other window starts
new_samp = np.random.randint(n_samp)
dists = np.abs(window_starts - new_samp)

if np.min(dists) > spacing_n_samps:
break
def _compute_cost(sig, window_starts, win_n_samps):
"""Compute the cost, as correlation coefficients and variance across windows.

else:
raise RuntimeError('SWM algorithm has difficulty finding a new window. \
Try increasing the spacing parameter.')
Parameters
----------
sig : 1d array
Time series.
window_starts : list of int
The list of window start definitions.
win_n_samps : int
The length of each window, in samples.

Returns
-------
corrs : 2d array
Window correlation matrix.
variance : float
Sum of the variance across windows.
"""

windows = np.array([sig[start:start+win_n_samps] for start in window_starts])

corrs = np.corrcoef(windows)

variance = windows.var(axis=1).sum()

return corrs, variance


def _compute_bounds(window_starts, win_spacing, start=None, end=None):
"""Compute lower & upper placement bounds for each window, given the spacing and signal edges."""

lower_bounds = window_starts[:-1] + win_spacing
lower_bounds = np.insert(lower_bounds, 0, start)

upper_bounds = window_starts[1:] - win_spacing
upper_bounds = np.insert(upper_bounds, len(upper_bounds), end)

return new_samp
return lower_bounds, upper_bounds
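
To illustrate `_compute_bounds`, a small worked example (values chosen arbitrarily):

>>> import numpy as np
>>> window_starts = np.array([0, 100, 220])
>>> _compute_bounds(window_starts, win_spacing=50, start=0, end=400)
(array([  0,  50, 150]), array([ 50, 170, 400]))

Each window's start is bounded below by the previous window's start plus the spacing, and above by the next window's start minus the spacing, with the first and last windows bounded by the signal edges.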
8 changes: 4 additions & 4 deletions neurodsp/tests/rhythm/test_swm.py
@@ -11,12 +11,12 @@ def test_sliding_window_matching(tsig):

win_len, win_spacing = 1, 0.5

pattern, starts, costs = sliding_window_matching(tsig, FS, win_len, win_spacing)
assert pattern.shape[-1] == int(FS * win_len)
windows, starts = sliding_window_matching(tsig, FS, win_len, win_spacing, var_thresh=0.1)
assert windows.shape[-1] == int(FS * win_len)

def test_sliding_window_matching_2d(tsig2d):

win_len, win_spacing = 1, 0.5

pattern, starts, costs = sliding_window_matching(tsig2d, FS, win_len, win_spacing)
assert pattern.shape[-1] == int(FS * win_len)
windows, starts = sliding_window_matching(tsig2d, FS, win_len, win_spacing, var_thresh=0.1)
assert windows.shape[-1] == int(FS * win_len)
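
Taken together, a minimal end-to-end sketch of the updated workflow, assembled from the docstring examples above (parameter choices are illustrative, not recommendations):

import numpy as np
from neurodsp.sim import sim_combined
from neurodsp.rhythm import sliding_window_matching
from neurodsp.plts.rhythm import plot_swm_pattern

# Simulate a bursty 20 Hz oscillation on top of aperiodic activity
components = {'sim_bursty_oscillation': {'freq': 20, 'phase': 'min'},
              'sim_powerlaw': {'f_range': (2, None)}}
sig = sim_combined(10, fs=500, components=components, component_variances=(1, .05))

# Find recurring patterns, then plot the average of the discovered windows
windows, starts = sliding_window_matching(sig, fs=500, win_len=0.05,
                                          win_spacing=0.05, var_thresh=.5)
plot_swm_pattern(np.mean(windows, axis=0))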