Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] - Add a helper function for making multiple axis layout #330

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions neurodsp/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

## Define default values for aesthetics
# These are all custom style arguments
SUPTITLE_FONTSIZE = 24
TITLE_FONTSIZE = 20
LABEL_SIZE = 16
TICK_LABELSIZE = 16
Expand Down
78 changes: 78 additions & 0 deletions neurodsp/plts/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,45 @@
"""Utility functions for plots."""

from copy import deepcopy
from functools import wraps
from os.path import join as pjoin

import matplotlib.pyplot as plt

from neurodsp.plts.settings import SUPTITLE_FONTSIZE

###################################################################################################
###################################################################################################

def subset_kwargs(kwargs, label):
"""Subset a set of kwargs from a dictionary.

Parameters
----------
kwargs : dict
Dictionary of keyword arguments.
label : str
Label to use to subset.
Any entries with label in the key will be subset from the kwargs dict.

Returns
-------
kwargs : dict
The kwargs dictionary, with subset items removed.
subset : dict
The collection of subset kwargs.
"""

kwargs = deepcopy(kwargs)

subset = {}
for key in list(kwargs.keys()):
if label in key:
subset[key] = kwargs.pop(key)

return kwargs, subset


def check_ax(ax, figsize=None):
"""Check whether a figure axes object is defined, define if not.

Expand Down Expand Up @@ -77,3 +109,49 @@ def save_figure(file_name, file_path=None, close=False, **save_kwargs):

if close:
plt.close()


def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6,
wspace=None, hspace=None, title=None, **plt_kwargs):
"""Make a subplot with multiple axes.

Parameters
----------
n_rows, n_cols : int
The number of rows and columns axes to create in the figure.
figsize : tuple of float, optional
Size to make the overall figure.
If not given, is estimated from the number of axes.
row_size, col_size : float, optional
The size to use per row / column.
Only used if `figsize` is None.
wspace, hspace : float, optional
Parameters for spacing between subplots.
These get passed into `plt.subplots_adjust`.
title : str, optional
A super title to add to the figure.
**plt_kwargs
Extra arguments to pass to `plt.subplots`.

Returns
-------
axes : 1d array of AxesSubplot
Collection of axes objects.
"""

if not figsize:
figsize = (n_cols * col_size, n_rows * row_size)

plt_kwargs, title_kwargs = subset_kwargs(plt_kwargs, 'title')

_, axes = plt.subplots(n_rows, n_cols, figsize=figsize, **plt_kwargs)

if wspace or hspace:
plt.subplots_adjust(wspace=wspace, hspace=hspace)

if title:
plt.suptitle(title,
fontsize=title_kwargs.pop('title_fontsize', SUPTITLE_FONTSIZE),
**title_kwargs)

return axes
21 changes: 21 additions & 0 deletions neurodsp/tests/plts/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,28 @@

import os

import matplotlib as mpl

from neurodsp.tests.settings import TEST_PLOTS_PATH

from neurodsp.plts.utils import *

###################################################################################################
###################################################################################################

def test_subset_kwargs():

kwargs = {'xlim' : [0, 10], 'ylim' : [2, 5],
'title_fontsize' : 24, 'title_fontweight': 'bold'}

kwargs1, subset1 = subset_kwargs(kwargs, 'lim')
assert list(kwargs1.keys()) == ['title_fontsize', 'title_fontweight']
assert list(subset1.keys()) == ['xlim', 'ylim']

kwargs2, subset2 = subset_kwargs(kwargs, 'title')
assert list(kwargs2.keys()) == ['xlim', 'ylim']
assert list(subset2.keys()) == ['title_fontsize', 'title_fontweight']

def test_check_ax():

# Check running with None Input
Expand Down Expand Up @@ -53,3 +68,9 @@ def test_save_figure():
plt.plot([1, 2], [3, 4])
save_figure(file_name='test_save_figure.pdf', file_path=TEST_PLOTS_PATH)
assert os.path.exists(os.path.join(TEST_PLOTS_PATH, 'test_save_figure.pdf'))

def test_make_axes():

axes = make_axes(2, 2)
assert axes.shape == (2, 2)
assert isinstance(axes[0, 0], mpl.axes._axes.Axes)
Loading