diff --git a/doc/api.rst b/doc/api.rst index d34b0fe6..e76e040b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -430,6 +430,7 @@ Spectral :toctree: generated/ plot_power_spectra + plot_spectra_3d plot_scv plot_scv_rs_lines plot_scv_rs_matrix @@ -465,6 +466,15 @@ Time Frequency plot_timefrequency +Aperiodic +~~~~~~~~~ + +.. currentmodule:: neurodsp.plts +.. autosummary:: + :toctree: generated/ + + plot_autocorr + Combined ~~~~~~~~ diff --git a/neurodsp/plts/__init__.py b/neurodsp/plts/__init__.py index f0fb58bf..5ce0b10d 100644 --- a/neurodsp/plts/__init__.py +++ b/neurodsp/plts/__init__.py @@ -4,7 +4,8 @@ plot_multi_time_series) from .filt import plot_filter_properties, plot_frequency_response, plot_impulse_response from .rhythm import plot_swm_pattern, plot_lagged_coherence -from .spectral import (plot_power_spectra, plot_spectral_hist, +from .spectral import (plot_power_spectra, plot_spectral_hist, plot_spectra_3d, plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix) from .timefrequency import plot_timefrequency +from .aperiodic import plot_autocorr from .combined import plot_timeseries_and_spectra diff --git a/neurodsp/plts/aperiodic.py b/neurodsp/plts/aperiodic.py new file mode 100644 index 00000000..5bc602fa --- /dev/null +++ b/neurodsp/plts/aperiodic.py @@ -0,0 +1,35 @@ +"""Plotting functions for neurodsp.aperiodic.""" + +from neurodsp.plts.style import style_plot +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot + +#################################################################################################### +#################################################################################################### + +@savefig +@style_plot +def plot_autocorr(timepoints, autocorrs, labels=None, colors=None, ax=None, **kwargs): + """Plot autocorrelation results. + + Parameters + ---------- + timepoints : 1d array + Time points, in samples, at which autocorrelations are computed. + autocorrs : array + Autocorrelation values, across time lags. + labels : str or list of str, optional + Labels for each time series. + colors : str or list of str + Colors to use to plot lines. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **kwargs + Keyword arguments for customizing the plot. + """ + + ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 5))) + + for time, ac, label, color in zip(*prepare_multi_plot(timepoints, autocorrs, labels, colors)): + ax.plot(time, ac, label=label, color=color) + + ax.set(xlabel='Lag (Samples)', ylabel='Autocorrelation') diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index c1ddfdad..02cf77db 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -1,12 +1,10 @@ """Plotting functions for neurodsp.spectral.""" -from itertools import repeat, cycle - import numpy as np import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig +from neurodsp.plts.utils import check_ax, check_ax_3d, savefig, prepare_multi_plot ################################################################################################### ################################################################################################### @@ -47,18 +45,8 @@ def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwarg ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 6))) - freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs - powers = [powers] if isinstance(powers, np.ndarray) and powers.ndim == 1 else powers - - if labels is not None: - labels = [labels] if not isinstance(labels, list) else labels - else: - labels = repeat(labels) - - colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) - - for freq, power, color, label in zip(freqs, powers, colors, labels): - ax.loglog(freq, power, color=color, label=label) + for freq, power, label, color in zip(*prepare_multi_plot(freqs, powers, labels, colors)): + ax.loglog(freq, power, label=label, color=color) ax.set_xlabel('Frequency (Hz)') ax.set_ylabel('Power ($V^2/Hz$)') @@ -235,3 +223,73 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None, if spectrum is not None: plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1]) ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8) + + +@savefig +@style_plot +def plot_spectra_3d(freqs, powers, log_freqs=False, log_powers=True, colors=None, + orientation=(20, -50), zoom=1.0, ax=None, **kwargs): + """Plot a series of power spectra in a 3D plot. + + Parameters + ---------- + freqs : 1d or 2d array or list of 1d array + Frequency vector. + powers : 2d array or list of 1d array + Power values. + log_freqs : bool, optional, default: False + Whether to plot the frequency values in log10 space. + log_powers : bool, optional, default: True + Whether to plot the power values in log10 space. + colors : str or list of str, optional + Colors to use to plot lines. + orientation : tuple of int, optional, default: (20, -50) + Orientation to set the 3D plot. See `Axes3D.view_init` for more information. + zoom : float, optional, default: 1.0 + Zoom scaling for the figure axis. See `Axes3D.set_box_aspect` for more information. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. Must be a 3D axis. + **kwargs + Keyword arguments for customizing the plot. + + Examples + -------- + Plot power spectra in 3D: + + >>> from neurodsp.sim import sim_combined + >>> from neurodsp.spectral import compute_spectrum + >>> sig1 = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {'exponent' : -1}, + ... 'sim_bursty_oscillation' : {'freq': 10}}) + >>> sig2 = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {'exponent' : -1.5}, + ... 'sim_bursty_oscillation' : {'freq': 10}}) + >>> freqs1, powers1 = compute_spectrum(sig1, fs=500) + >>> freqs2, powers2 = compute_spectrum(sig2, fs=500) + >>> plot_spectra_3d([freqs1, freqs2], [powers1, powers2]) + """ + + ax = check_ax_3d(ax) + + n_spectra = len(powers) + + for ind, (freq, power, _, color) in \ + enumerate(zip(*prepare_multi_plot(freqs, powers, None, colors))): + ax.plot(xs=np.log10(freq) if log_freqs else freq, + ys=[ind] * len(freq), + zs=np.log10(power) if log_powers else power, + color=color, + **kwargs) + + ax.set( + xlabel='Frequency (Hz)', + ylabel='Channels', + zlabel='Power', + ylim=[0, n_spectra - 1], + ) + + yticks = list(range(n_spectra)) + ax.set_yticks(yticks, yticks) + + ax.view_init(*orientation) + ax.set_box_aspect(None, zoom=zoom) diff --git a/neurodsp/plts/style.py b/neurodsp/plts/style.py index e71506c8..ddd37bb2 100644 --- a/neurodsp/plts/style.py +++ b/neurodsp/plts/style.py @@ -113,10 +113,12 @@ def apply_custom_style(ax, **kwargs): if ax.get_title(): ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE)) - # Settings for the axis labels + # Settings for the axis labels, including checking & setting for 3D axis label_size = kwargs.pop('label_size', LABEL_SIZE) ax.xaxis.label.set_size(label_size) ax.yaxis.label.set_size(label_size) + if hasattr(ax, 'zaxis'): + ax.zaxis.label.set_size(label_size) # Settings for the axis ticks ax.tick_params(axis='both', which='major', diff --git a/neurodsp/plts/time_series.py b/neurodsp/plts/time_series.py index 114fb2b5..1db61142 100644 --- a/neurodsp/plts/time_series.py +++ b/neurodsp/plts/time_series.py @@ -1,12 +1,12 @@ """Plots for time series.""" -from itertools import repeat, cycle +from itertools import repeat import numpy as np import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot from neurodsp.utils.data import create_samples from neurodsp.utils.checks import check_param_options @@ -49,18 +49,12 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs): ax = check_ax(ax, kwargs.pop('figsize', (15, 3))) - sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs times, xlabel = _check_times(times, sigs) - - if labels is not None: - labels = [labels] if not isinstance(labels, list) else labels - else: - labels = repeat(labels) + times, sigs, colors, labels = prepare_multi_plot(times, sigs, colors, labels) # If not provided, default colors for up to two signals to be black & red - if not colors and len(sigs) <= 2: + if isinstance(colors, repeat) and next(colors) is None and len(sigs) <= 2: colors = ['k', 'r'] - colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) for time, sig, color, label in zip(times, sigs, colors, labels): ax.plot(time, sig, color=color, label=label) @@ -174,22 +168,19 @@ def plot_multi_time_series(times, sigs, colors=None, ax=None, **plt_kwargs): Keyword arguments for customizing the plot. """ - colors = 'black' if not colors else colors - colors = repeat(colors) if isinstance(colors, str) else iter(colors) - ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5))) - sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs + colors = 'black' if not colors else colors + times, xlabel = _check_times(times, sigs) + times, sigs, _, colors = prepare_multi_plot(times, sigs, None, colors) step = 0.8 * np.ptp(sigs[0]) for ind, (time, sig) in enumerate(zip(times, sigs)): ax.plot(time, sig+step*ind, color=next(colors), **plt_kwargs) - ax.set(yticks=[]) - ax.set_xlabel(xlabel) - ax.set_ylabel('Channels') + ax.set(xlabel=xlabel, ylabel='Channels', yticks=[]) def _check_times(times, sigs): @@ -197,9 +188,11 @@ def _check_times(times, sigs): xlabel = 'Time (s)' if times is None: - times = create_samples(len(sigs[0])) + if isinstance(sigs, list) or (isinstance(sigs, np.ndarray) and sigs.ndim == 2): + n_samples = len(sigs[0]) + else: + n_samples = len(sigs) + times = create_samples(n_samples) xlabel = 'Samples' - times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times - return times, xlabel diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index 553791bb..aed83cf2 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -3,7 +3,9 @@ from copy import deepcopy from functools import wraps from os.path import join as pjoin +from itertools import repeat, cycle +import numpy as np import matplotlib.pyplot as plt from neurodsp.plts.settings import SUPTITLE_FONTSIZE @@ -60,6 +62,36 @@ def check_ax(ax, figsize=None): return ax +def check_ax_3d(ax, figsize=None): + """Check whether a 3D figure axes object is defined, define if not. + + Parameters + ---------- + ax : matplotlib.Axes or None + Axes object to check if is defined. Must be 3D. + + Returns + ------- + ax : matplotlib.Axes + Figure axes object to use. + + Raises + ------ + ValueError + If the ax input is a defined axis, but is not 3D. + """ + + if ax and '3d' not in ax.name: + raise ValueError('Provided axis is not 3D.') + + if not ax: + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(projection='3d') + + return ax + + def savefig(func): """Decorator function to save out figures.""" @@ -155,3 +187,46 @@ def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6, **title_kwargs) return axes + + +def prepare_multi_plot(xs, ys, labels=None, colors=None): + """Prepare inputs for plotting one or more elements in a loop. + + Parameters + ---------- + xs, ys : 1d or 2d array + Plot data. + labels : str or list + Label(s) for the plot input(s). + colors : str or iterable + Color(s) to plot input(s). + + Returns + ------- + xs, ys : iterable + Plot data. + labels : iterable + Label(s) for the plot input(s). + colors : iterable + Color(s) to plot input(s). + + Notes + ----- + This function takes inputs that can reflect one or more plot elements, and + prepares the inputs to be iterable for plotting in a loop. + """ + + xs = repeat(xs) if isinstance(xs, np.ndarray) and xs.ndim == 1 else xs + ys = [ys] if isinstance(ys, np.ndarray) and ys.ndim == 1 else ys + + # Collect definition of collection items considered iterables to check against + iterables = (list, tuple, np.ndarray) + + if labels is not None: + labels = [labels] if not isinstance(labels, iterables) else labels + else: + labels = repeat(labels) + + colors = repeat(colors) if not isinstance(colors, iterables) else cycle(colors) + + return xs, ys, labels, colors diff --git a/neurodsp/tests/plts/test_aperiodic.py b/neurodsp/tests/plts/test_aperiodic.py new file mode 100644 index 00000000..cfa675f7 --- /dev/null +++ b/neurodsp/tests/plts/test_aperiodic.py @@ -0,0 +1,25 @@ +"""Tests for neurodsp.plts.aperiodic.""" + +from neurodsp.aperiodic.autocorr import compute_autocorr + +from neurodsp.tests.settings import TEST_PLOTS_PATH, FS +from neurodsp.tests.tutils import plot_test + +from neurodsp.plts.aperiodic import * + +################################################################################################### +################################################################################################### + +def tests_plot_autocorr(tsig, tsig_comb): + + times1, acs1 = compute_autocorr(tsig, max_lag=150) + times2, acs2 = compute_autocorr(tsig_comb, max_lag=150) + + plot_autocorr(times1, acs1, + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_autocorr-1.png') + + plot_autocorr([times1, times2], [acs1, acs2], + labels=['first', 'second'], colors=['k', 'r'], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_autocorr-2.png') diff --git a/neurodsp/tests/plts/test_spectral.py b/neurodsp/tests/plts/test_spectral.py index 049ce8bb..fab0107c 100644 --- a/neurodsp/tests/plts/test_spectral.py +++ b/neurodsp/tests/plts/test_spectral.py @@ -68,3 +68,18 @@ def test_plot_spectral_hist(tsig_comb): spectrum=spectrum, spectrum_freqs=spectrum_freqs, save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectral_hist.png') + +@plot_test +def test_plot_spectra_3d(tsig_comb, tsig_burst): + + freqs1, powers1 = compute_spectrum(tsig_comb, FS) + freqs2, powers2 = compute_spectrum(tsig_burst, FS) + + plot_spectra_3d([freqs1, freqs2], [powers1, powers2], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectral3D_1.png') + + plot_spectra_3d(freqs1, [powers1, powers2, powers1, powers2], + colors=['r', 'y', 'b', 'g'], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectral3D_2.png') diff --git a/neurodsp/tests/plts/test_utils.py b/neurodsp/tests/plts/test_utils.py index f727c42e..527ff18d 100644 --- a/neurodsp/tests/plts/test_utils.py +++ b/neurodsp/tests/plts/test_utils.py @@ -1,7 +1,11 @@ """Tests for neurodsp.plts.utils.""" +from pytest import raises + import os +import itertools +import numpy as np import matplotlib as mpl from neurodsp.tests.settings import TEST_PLOTS_PATH @@ -40,6 +44,22 @@ def test_check_ax(): fig = plt.gcf() assert list(fig.get_size_inches()) == figsize +def test_check_ax_3d(): + + # Check running with None Input + ax = check_ax(None) + + # Check error if given a non 3D axis + with raises(ValueError): + _, ax = plt.subplots() + nax = check_ax_3d(ax) + + # Check running with pre-created axis + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + nax = check_ax(ax) + assert nax == ax + def test_savefig(): @savefig @@ -74,3 +94,28 @@ def test_make_axes(): axes = make_axes(2, 2) assert axes.shape == (2, 2) assert isinstance(axes[0, 0], mpl.axes._axes.Axes) + +def test_prepare_multi_plot(): + + xs1 = np.array([1, 2, 3]) + ys1 = np.array([1, 2, 3]) + labels1 = None + colors1 = None + + # 1 input + xs1o, ys1o, labels1o, colors1o = prepare_multi_plot(xs1, ys1, labels1, colors1) + assert isinstance(xs1o, itertools.repeat) + assert isinstance(ys1o, list) + assert isinstance(labels1o, itertools.repeat) + assert isinstance(colors1o, itertools.repeat) + + # multiple inputs + xs2 = [np.array([1, 2, 3]), np.array([4, 5, 6])] + ys2 = [np.array([1, 2, 3]), np.array([4, 5, 6])] + labels2 = ['A', 'B'] + colors2 = ['blue', 'red'] + xs2o, ys2o, labels2o, colors2o = prepare_multi_plot(xs2, ys2, labels2, colors2) + assert isinstance(xs2o, list) + assert isinstance(ys2o, list) + assert isinstance(labels2o, list) + assert isinstance(colors2o, itertools.cycle) diff --git a/tutorials/aperiodic/plot_Autocorr.py b/tutorials/aperiodic/plot_Autocorr.py index 8adc6580..83b9da29 100644 --- a/tutorials/aperiodic/plot_Autocorr.py +++ b/tutorials/aperiodic/plot_Autocorr.py @@ -15,12 +15,12 @@ # sphinx_gallery_thumbnail_number = 1 import numpy as np -import matplotlib.pyplot as plt from neurodsp.sim import sim_powerlaw, sim_oscillation # Import the function for computing autocorrelation from neurodsp.aperiodic import compute_autocorr +from neurodsp.plts import plot_autocorr ################################################################################################### # Autocorrelation Measures @@ -92,9 +92,7 @@ ################################################################################################### # Plot autocorrelations -_, ax = plt.subplots(figsize=(6, 4)) -ax.plot(timepoints_osc1, autocorrs_osc1) -ax.set(xlabel='lag (samples)', ylabel='autocorrelation'); +plot_autocorr(timepoints_osc1, autocorrs_osc1) ################################################################################################### # @@ -109,11 +107,9 @@ ################################################################################################### # Plot autocorrelations for two different sinusoids -_, ax = plt.subplots(figsize=(6, 4)) -ax.plot(timepoints_osc1, autocorrs_osc1, alpha=0.75, label='10 Hz') -ax.plot(timepoints_osc2, autocorrs_osc2, alpha=0.75, label='20 Hz') -ax.set(xlabel='lag (samples)', ylabel='autocorrelation') -plt.legend(loc='upper right') +plot_autocorr([timepoints_osc1, timepoints_osc2], + [autocorrs_osc1, autocorrs_osc2], + labels=['10 Hz', '20 Hz']) ################################################################################################### # @@ -152,11 +148,9 @@ ################################################################################################### # Plot the autocorrelations of the aperiodic signals -_, ax = plt.subplots(figsize=(5, 4)) -ax.plot(timepoints_wn, autocorrs_wn, label='White Noise') -ax.plot(timepoints_pn, autocorrs_pn, label='Pink Noise') -ax.set(xlabel="lag (samples)", ylabel="autocorrelation") -plt.legend() +plot_autocorr([timepoints_wn, timepoints_pn], + [autocorrs_wn, autocorrs_pn], + labels=['White Noise', 'Pink Noise']) ################################################################################################### # @@ -165,4 +159,4 @@ # # By comparison, the pink noise signal has a pattern of decreasing autocorrelation # across increasing lags. This is characteristic of powerlaw data. -# \ No newline at end of file +#