Skip to content

Commit

Permalink
Added named_arrays.plt.stairs() function, a thin wrapper around `ma…
Browse files Browse the repository at this point in the history
…tplotlib.pyplot.stairs()`. (#91)
  • Loading branch information
byrdie authored Nov 4, 2024
1 parent deb1c80 commit 8c706e8
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 0 deletions.
103 changes: 103 additions & 0 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,109 @@ def plt_scatter(
return result


@_implements(na.plt.stairs)
def plt_stairs(
*args: na.AbstractScalarArray,
ax: None | matplotlib.axes.Axes | na.ScalarArray[npt.NDArray[matplotlib.axes.Axes]] = None,
axis: None | str = None,
where: bool | na.AbstractScalarArray = True,
**kwargs,
) -> na.ScalarArray[npt.NDArray[None | matplotlib.artist.Artist]]:

if len(args) == 1:
edges = None
values, = args
elif len(args) == 2:
edges, values = args
else: # pragma: nocover
raise ValueError(
f"incorrect number of arguments, expected 1 or 2, got {len(args)}"
)

try:
values = scalars._normalize(values)
edges = scalars._normalize(edges) if edges is not None else edges
where = scalars._normalize(where)
kwargs = {k: scalars._normalize(kwargs[k]) for k in kwargs}
except na.ScalarTypeError: # pragma: nocover
return NotImplemented

if ax is None:
ax = plt.gca()
ax = na.as_named_array(ax)

if axis is None:
if len(values.shape) != 1:
raise ValueError(
f"if {axis=}, {values.shape=} should have only one element."
)
axis = next(iter(values.shape))
else:
if axis not in values.shape:
raise ValueError(
f"{axis=} must be an element of {values.shape}"
)

shape_values = na.shape(values)
shape_edges = na.shape(edges)
shape_args = na.broadcast_shapes(
shape_values,
{a: shape_edges[a] for a in shape_edges if a != axis},
)

shape = na.broadcast_shapes(ax.shape, shape_args)

shape_orthogonal = {a: shape[a] for a in shape if a != axis}

values = na.broadcast_to(values, shape)

if edges is not None:
edges = na.broadcast_to(
array=edges,
shape=shape_orthogonal | {axis: shape[axis] + 1},
)

if not set(ax.shape).issubset(shape_orthogonal): # pragma: nocover
raise ValueError(
f"the shape of `ax`, {ax.shape}, "
f"should be a subset of the broadcasted shape of `*args` excluding `axis`, {shape_orthogonal}",
)
ax = ax.broadcast_to(shape_orthogonal)

if not set(where.shape).issubset(shape_orthogonal): # pragma: nocover
raise ValueError(
f"the shape of `where`, {where.shape}, "
f"should be a subset of the broadcasted shape of `*args` excluding `axis`, {shape_orthogonal}"
)
where = where.broadcast_to(shape_orthogonal)

kwargs_broadcasted = dict()
for k in kwargs:
kwarg = kwargs[k]
if not set(na.shape(kwarg)).issubset(shape_orthogonal): # pragma: nocover
raise ValueError(
f"the shape of `{k}`, {na.shape(kwarg)}, "
f"should be a subset of the broadcasted shape of `*args` excluding `axis`, {shape_orthogonal}"
)
kwargs_broadcasted[k] = na.broadcast_to(kwarg, shape_orthogonal)
kwargs = kwargs_broadcasted

result = na.ScalarArray.empty(shape=shape_orthogonal, dtype=object)

for index in na.ndindex(shape_orthogonal):
if where[index]:
values_index = values[index].ndarray
edges_index = edges[index].ndarray if edges is not None else edges
kwargs_index = {k: kwargs[k][index].ndarray for k in kwargs}
result[index] = ax[index].ndarray.stairs(
values=values_index,
edges=edges_index,
**kwargs_index,
)

return result


@_implements(na.plt.imshow)
def plt_imshow(
X: na.AbstractScalarArray,
Expand Down
56 changes: 56 additions & 0 deletions named_arrays/_scalars/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,62 @@ def test_histogram2d(
assert result.inputs.y.unit_normalized.is_equivalent(unit)
assert result.outputs.unit_normalized.is_equivalent(unit_weights)

@pytest.mark.parametrize(
argnames="edges",
argvalues=[
None,
na.linspace(-1, 1, axis="y", num=_num_y + 1)
]
)
@pytest.mark.parametrize(
argnames="ax",
argvalues=[
None,
na.plt.subplots(axis_rows="x", nrows=_num_x)[1],
],
)
@pytest.mark.parametrize(
argnames="axis",
argvalues=[
None,
"y",
],
)
class TestPltStairs:
def test_plt_stairs(
self,
array: na.AbstractScalar,
edges: None | na.AbstractScalar,
ax: None | matplotlib.axes.Axes,
axis: bool | str,
):
if edges is None:
args = (array, )
else:
args = (edges, array)

kwargs = dict(
ax=ax,
axis=axis,
)

if axis is None:
if len(na.shape(array)) != 1:
with pytest.raises(ValueError):
na.plt.stairs(*args, **kwargs)
return
else:
if axis not in na.shape(array):
with pytest.raises(ValueError):
na.plt.stairs(*args, **kwargs)
return

with astropy.visualization.quantity_support():
result = na.plt.stairs(*args, **kwargs)

assert isinstance(result, na.AbstractArray)
assert result.dtype == matplotlib.artist.Artist


class AbstractTestAbstractScalarArray(
AbstractTestAbstractScalar,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,116 @@ def plt_scatter(
return result


@_implements(na.plt.stairs)
def plt_stairs(
*args: na.AbstractScalar,
ax: None | matplotlib.axes.Axes = None,
axis: None | str = None,
where: bool | na.AbstractScalarArray = True,
**kwargs,
) -> na.UncertainScalarArray[
npt.NDArray[matplotlib.artist.Artist],
npt.NDArray[matplotlib.artist.Artist]
]:

if len(args) == 1:
edges = None
values, = args
elif len(args) == 2:
edges, values = args
else: # pragma: nocover
raise ValueError(
f"incorrect number of arguments, expected 1 or 2, got {len(args)}"
)

try:
values = uncertainties._normalize(values)
edges = uncertainties._normalize(edges) if edges is not None else edges
where = uncertainties._normalize(where)
kwargs = {k: uncertainties._normalize(kwargs[k]) for k in kwargs}
except na.UncertainScalarTypeError: # pragma: nocover
return NotImplemented

if axis is None:
if len(values.shape) != 1:
raise ValueError(
f"if {axis=}, {values.shape=} should have only one element."
)
axis = next(iter(values.shape))
else:
if axis not in values.shape:
raise ValueError(
f"{axis=} must be an element of {values.shape}"
)

shape_values = na.shape(values)
shape_edges = na.shape(edges)
shape_args = na.broadcast_shapes(
shape_values,
{a: shape_edges[a] for a in shape_edges if a != axis},
)

shape = na.broadcast_shapes(na.shape(ax), shape_args)

values = na.broadcast_to(values, shape)

if edges is not None:
edges = na.broadcast_to(
array=edges,
shape={a: shape[a] + 1 if a == axis else shape[a] for a in shape},
)

axis_distribution = values.axis_distribution
dshape_edges = na.shape(edges.distribution) if edges is not None else dict()
shape_distribution = na.broadcast_shapes(
na.shape(values.distribution),
{a: dshape_edges[a] for a in dshape_edges if a != axis},
)

if axis_distribution in shape_distribution:
num_distribution = shape_distribution[axis_distribution]
else:
num_distribution = 1

if num_distribution == 0:
alpha = 1
else:
alpha = max(1 / num_distribution, 1/255)
if "alpha" in kwargs:
kwargs["alpha"] *= alpha
else:
kwargs["alpha"] = na.UncertainScalarArray(1, alpha)

if "color" not in kwargs:
color_cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']
shape_orthogonal = {a: shape[a] for a in shape if a != axis}
color = na.ScalarArray.empty(shape=shape_orthogonal, dtype=object)
for i, index in enumerate(color.ndindex()):
color[index] = color_cycle[i % len(color_cycle)]
kwargs["color"] = uncertainties._normalize(color)

result = na.UncertainScalarArray(
nominal=na.plt.stairs(
edges.nominal if edges is not None else edges,
values.nominal,
ax=ax,
axis=axis,
where=where.nominal,
**{k: kwargs[k].nominal for k in kwargs}
),
distribution=na.plt.stairs(
edges.distribution if edges is not None else edges,
values.distribution,
ax=ax,
axis=axis,
where=where.distribution,
**{k: kwargs[k].distribution for k in kwargs}
)
)

return result


@_implements(na.jacobian)
def jacobian(
function: Callable[[na.AbstractScalar], na.AbstractScalar],
Expand Down
89 changes: 89 additions & 0 deletions named_arrays/plt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"plot",
"fill",
"scatter",
"stairs",
"imshow",
"pcolormesh",
"pcolormovie",
Expand Down Expand Up @@ -280,6 +281,94 @@ def fill(
)


def stairs(
*args: na.AbstractArray,
ax: None | matplotlib.axes.Axes | na.ScalarArray[npt.NDArray[matplotlib.axes.Axes]] = None,
axis: None | str = None,
where: bool | na.AbstractScalar = True,
**kwargs,
) -> na.ScalarArray[npt.NDArray[None | matplotlib.artist.Artist]]:
"""
A thin wrapper around :meth:`matplotlib.axes.Axes.stairs` for named arrays.
The main difference of this function from :func:`matplotlib.pyplot.stairs`
is the addition of the ``axis`` parameter indicating along which axis the
lines should be connected.
Another difference is that this function swaps the order of `values`
and `edges` so that the signature matches :func:`plot`.
Parameters
----------
args
Either a single instance of :class:`AbstractScalar` representing the step heights,
or a pair of instances of :class:`AbstractScalar` representing the edges and heights.
ax
The instances of :class:`matplotlib.axes.Axes` to use.
If :obj:`None`, calls :func:`matplotlib.pyplot.gca` to get the current axes.
If an instance of :class:`named_arrays.ScalarArray`, ``ax.shape`` should
be a subset of the broadcasted shape of ``*args``.
axis
The name of the axis that the plot lines should be connected along.
If :obj:`None`, the broadcasted shape of ``args`` should have only one element,
otherwise a :class:`ValueError` is raised.
where
A boolean array that selects which elements to plot
kwargs
Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.stairs`.
These can be instances of :class:`named_arrays.AbstractArray`.
Examples
--------
Plot a single scalar
.. jupyter-execute::
import numpy as np
import matplotlib.pyplot as plt
import named_arrays as na
x = na.linspace(0, 2 * np.pi, axis="x", num=101)
a = na.linspace(0, 2 * np.pi, axis="x", num=100, centers=True)
y = np.sin(a)
plt.figure();
na.plt.stairs(x, y);
Plot an array of scalars
.. jupyter-execute::
z = na.linspace(0, np.pi, axis="z", num=5)
y = np.sin(a - z)
plt.figure();
na.plt.stairs(x, y, axis="x");
Plot an uncertain scalar
.. jupyter-execute::
ua = na.NormalUncertainScalarArray(a, width=0.2)
uy = np.sin(ua)
plt.figure();
na.plt.stairs(x, uy);
"""
return na._named_array_function(
stairs,
*args,
ax=ax,
axis=axis,
where=where,
**kwargs,
)


def scatter(
*args: na.AbstractScalar,
s: None | na.AbstractScalarArray = None,
Expand Down

0 comments on commit 8c706e8

Please sign in to comment.