Skip to content

Commit

Permalink
Changed behavior of named_arrays.plt.plot() to broadcast against th…
Browse files Browse the repository at this point in the history
…e input Axes. (#83)
  • Loading branch information
byrdie authored Oct 12, 2024
1 parent 160a4cd commit b2fe95e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
11 changes: 7 additions & 4 deletions named_arrays/_scalars/scalar_named_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,17 @@ def plt_plot_like(
ax = plt.gca()
ax = na.as_named_array(ax)

shape = na.shape_broadcasted(*args)
shape_args = na.shape_broadcasted(*args)

shape = na.broadcast_shapes(ax.shape, shape_args)

if axis is None:
if len(shape) != 1:
if len(shape_args) != 1:
raise ValueError(
f"if `axis` is `None`, the broadcasted shape of `*args`, {shape}, should have one element"
f"if `axis` is `None`, the broadcasted shape of `*args`, "
f"{shape_args}, should have one element"
)
axis = next(iter(shape))
axis = next(iter(shape_args))

shape_orthogonal = {a: shape[a] for a in shape if a != axis}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,17 @@ def plt_plot_like(
except na.UncertainScalarTypeError:
return NotImplemented

shape = na.shape_broadcasted(*args)
shape_args = na.shape_broadcasted(*args)

shape = na.broadcast_shapes(na.shape(ax), shape_args)

if axis is None:
if len(shape) != 1:
if len(shape_args) != 1:
raise ValueError(
f"if `axis` is `None`, the broadcasted shape of `*args`, {shape}, should have one element"
f"if `axis` is `None`, the broadcasted shape of `*args`, "
f"{shape_args}, should have one element"
)
axis = next(iter(shape))
axis = next(iter(shape_args))

args = tuple(arg.broadcast_to(shape) for arg in args)

Expand Down
9 changes: 5 additions & 4 deletions named_arrays/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1263,7 +1263,6 @@ def test_interp(
argnames="ax",
argvalues=[
np._NoValue,
plt.subplots()[1],
na.plt.subplots(axis_rows="x", nrows=num_x)[1],
]
)
Expand Down Expand Up @@ -1295,21 +1294,23 @@ def test_plt_plot_like(
if alpha is not np._NoValue:
kwargs["alpha"] = alpha

shape = na.shape_broadcasted(*args)
shape_args = na.shape_broadcasted(*args)

shape = na.broadcast_shapes(na.shape(ax), shape_args)

axis_normalized = axis
if axis_normalized is np._NoValue:
axis_normalized = None

if axis_normalized is None:
if len(shape) != 1:
if len(shape_args) != 1:
with pytest.raises(
expected_exception=ValueError,
match="if `axis` is `None`, the broadcasted shape of .* should have one element"
):
func(*args, **kwargs)
return
axis_normalized = next(iter(shape))
axis_normalized = next(iter(shape_args))

shape_orthogonal = {a: shape[a] for a in shape if a != axis_normalized}

Expand Down

0 comments on commit b2fe95e

Please sign in to comment.