From b4e176a9bf0c6f823d2d2aa6b465cb1b6263ae73 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 14 Apr 2024 12:27:16 -0400 Subject: [PATCH] add make_axes helper --- neurodsp/plts/settings.py | 1 + neurodsp/plts/utils.py | 78 +++++++++++++++++++++++++++++++ neurodsp/tests/plts/test_utils.py | 21 +++++++++ 3 files changed, 100 insertions(+) diff --git a/neurodsp/plts/settings.py b/neurodsp/plts/settings.py index 3bf1888d..c4ff7990 100644 --- a/neurodsp/plts/settings.py +++ b/neurodsp/plts/settings.py @@ -29,6 +29,7 @@ ## Define default values for aesthetics # These are all custom style arguments +SUPTITLE_FONTSIZE = 24 TITLE_FONTSIZE = 20 LABEL_SIZE = 16 TICK_LABELSIZE = 16 diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index 79b30d7e..553791bb 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -1,13 +1,45 @@ """Utility functions for plots.""" +from copy import deepcopy from functools import wraps from os.path import join as pjoin import matplotlib.pyplot as plt +from neurodsp.plts.settings import SUPTITLE_FONTSIZE + ################################################################################################### ################################################################################################### +def subset_kwargs(kwargs, label): + """Subset a set of kwargs from a dictionary. + + Parameters + ---------- + kwargs : dict + Dictionary of keyword arguments. + label : str + Label to use to subset. + Any entries with label in the key will be subset from the kwargs dict. + + Returns + ------- + kwargs : dict + The kwargs dictionary, with subset items removed. + subset : dict + The collection of subset kwargs. + """ + + kwargs = deepcopy(kwargs) + + subset = {} + for key in list(kwargs.keys()): + if label in key: + subset[key] = kwargs.pop(key) + + return kwargs, subset + + def check_ax(ax, figsize=None): """Check whether a figure axes object is defined, define if not. @@ -77,3 +109,49 @@ def save_figure(file_name, file_path=None, close=False, **save_kwargs): if close: plt.close() + + +def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6, + wspace=None, hspace=None, title=None, **plt_kwargs): + """Make a subplot with multiple axes. + + Parameters + ---------- + n_rows, n_cols : int + The number of rows and columns axes to create in the figure. + figsize : tuple of float, optional + Size to make the overall figure. + If not given, is estimated from the number of axes. + row_size, col_size : float, optional + The size to use per row / column. + Only used if `figsize` is None. + wspace, hspace : float, optional + Parameters for spacing between subplots. + These get passed into `plt.subplots_adjust`. + title : str, optional + A super title to add to the figure. + **plt_kwargs + Extra arguments to pass to `plt.subplots`. + + Returns + ------- + axes : 1d array of AxesSubplot + Collection of axes objects. + """ + + if not figsize: + figsize = (n_cols * col_size, n_rows * row_size) + + plt_kwargs, title_kwargs = subset_kwargs(plt_kwargs, 'title') + + _, axes = plt.subplots(n_rows, n_cols, figsize=figsize, **plt_kwargs) + + if wspace or hspace: + plt.subplots_adjust(wspace=wspace, hspace=hspace) + + if title: + plt.suptitle(title, + fontsize=title_kwargs.pop('title_fontsize', SUPTITLE_FONTSIZE), + **title_kwargs) + + return axes diff --git a/neurodsp/tests/plts/test_utils.py b/neurodsp/tests/plts/test_utils.py index 77cc85f0..f727c42e 100644 --- a/neurodsp/tests/plts/test_utils.py +++ b/neurodsp/tests/plts/test_utils.py @@ -2,6 +2,8 @@ import os +import matplotlib as mpl + from neurodsp.tests.settings import TEST_PLOTS_PATH from neurodsp.plts.utils import * @@ -9,6 +11,19 @@ ################################################################################################### ################################################################################################### +def test_subset_kwargs(): + + kwargs = {'xlim' : [0, 10], 'ylim' : [2, 5], + 'title_fontsize' : 24, 'title_fontweight': 'bold'} + + kwargs1, subset1 = subset_kwargs(kwargs, 'lim') + assert list(kwargs1.keys()) == ['title_fontsize', 'title_fontweight'] + assert list(subset1.keys()) == ['xlim', 'ylim'] + + kwargs2, subset2 = subset_kwargs(kwargs, 'title') + assert list(kwargs2.keys()) == ['xlim', 'ylim'] + assert list(subset2.keys()) == ['title_fontsize', 'title_fontweight'] + def test_check_ax(): # Check running with None Input @@ -53,3 +68,9 @@ def test_save_figure(): plt.plot([1, 2], [3, 4]) save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH) assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf')) + +def test_make_axes(): + + axes = make_axes(2, 2) + assert axes.shape == (2, 2) + assert isinstance(axes[0, 0], mpl.axes._axes.Axes)