Skip to content

Commit

Permalink
add make_axes helper
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 14, 2024
1 parent d3abb38 commit b4e176a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 0 deletions.
1 change: 1 addition & 0 deletions neurodsp/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

## Define default values for aesthetics
# These are all custom style arguments
SUPTITLE_FONTSIZE = 24
TITLE_FONTSIZE = 20
LABEL_SIZE = 16
TICK_LABELSIZE = 16
Expand Down
78 changes: 78 additions & 0 deletions neurodsp/plts/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,45 @@
"""Utility functions for plots."""

from copy import deepcopy
from functools import wraps
from os.path import join as pjoin

import matplotlib.pyplot as plt

from neurodsp.plts.settings import SUPTITLE_FONTSIZE

###################################################################################################
###################################################################################################

def subset_kwargs(kwargs, label):
"""Subset a set of kwargs from a dictionary.
Parameters
----------
kwargs : dict
Dictionary of keyword arguments.
label : str
Label to use to subset.
Any entries with label in the key will be subset from the kwargs dict.
Returns
-------
kwargs : dict
The kwargs dictionary, with subset items removed.
subset : dict
The collection of subset kwargs.
"""

kwargs = deepcopy(kwargs)

subset = {}
for key in list(kwargs.keys()):
if label in key:
subset[key] = kwargs.pop(key)

return kwargs, subset


def check_ax(ax, figsize=None):
"""Check whether a figure axes object is defined, define if not.
Expand Down Expand Up @@ -77,3 +109,49 @@ def save_figure(file_name, file_path=None, close=False, **save_kwargs):

if close:
plt.close()


def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6,
wspace=None, hspace=None, title=None, **plt_kwargs):
"""Make a subplot with multiple axes.
Parameters
----------
n_rows, n_cols : int
The number of rows and columns axes to create in the figure.
figsize : tuple of float, optional
Size to make the overall figure.
If not given, is estimated from the number of axes.
row_size, col_size : float, optional
The size to use per row / column.
Only used if `figsize` is None.
wspace, hspace : float, optional
Parameters for spacing between subplots.
These get passed into `plt.subplots_adjust`.
title : str, optional
A super title to add to the figure.
**plt_kwargs
Extra arguments to pass to `plt.subplots`.
Returns
-------
axes : 1d array of AxesSubplot
Collection of axes objects.
"""

if not figsize:
figsize = (n_cols * col_size, n_rows * row_size)

plt_kwargs, title_kwargs = subset_kwargs(plt_kwargs, 'title')

_, axes = plt.subplots(n_rows, n_cols, figsize=figsize, **plt_kwargs)

if wspace or hspace:
plt.subplots_adjust(wspace=wspace, hspace=hspace)

if title:
plt.suptitle(title,
fontsize=title_kwargs.pop('title_fontsize', SUPTITLE_FONTSIZE),
**title_kwargs)

return axes
21 changes: 21 additions & 0 deletions neurodsp/tests/plts/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@

import os

import matplotlib as mpl

from neurodsp.tests.settings import TEST_PLOTS_PATH

from neurodsp.plts.utils import *

###################################################################################################
###################################################################################################

def test_subset_kwargs():

kwargs = {'xlim' : [0, 10], 'ylim' : [2, 5],
'title_fontsize' : 24, 'title_fontweight': 'bold'}

kwargs1, subset1 = subset_kwargs(kwargs, 'lim')
assert list(kwargs1.keys()) == ['title_fontsize', 'title_fontweight']
assert list(subset1.keys()) == ['xlim', 'ylim']

kwargs2, subset2 = subset_kwargs(kwargs, 'title')
assert list(kwargs2.keys()) == ['xlim', 'ylim']
assert list(subset2.keys()) == ['title_fontsize', 'title_fontweight']

def test_check_ax():

# Check running with None Input
Expand Down Expand Up @@ -53,3 +68,9 @@ def test_save_figure():
plt.plot([1, 2], [3, 4])
save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH)
assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf'))

def test_make_axes():

axes = make_axes(2, 2)
assert axes.shape == (2, 2)
assert isinstance(axes[0, 0], mpl.axes._axes.Axes)

0 comments on commit b4e176a

Please sign in to comment.