From b2fe95e0c8d8052dca0dc1485ec29c1090a47df0 Mon Sep 17 00:00:00 2001 From: Roy Smart Date: Sat, 12 Oct 2024 14:29:53 -0600 Subject: [PATCH] Changed behavior of `named_arrays.plt.plot()` to broadcast against the input Axes. (#83) --- named_arrays/_scalars/scalar_named_array_functions.py | 11 +++++++---- .../uncertainties_named_array_functions.py | 11 +++++++---- named_arrays/tests/test_core.py | 9 +++++---- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/named_arrays/_scalars/scalar_named_array_functions.py b/named_arrays/_scalars/scalar_named_array_functions.py index 6744add..e60c6b6 100644 --- a/named_arrays/_scalars/scalar_named_array_functions.py +++ b/named_arrays/_scalars/scalar_named_array_functions.py @@ -459,14 +459,17 @@ def plt_plot_like( ax = plt.gca() ax = na.as_named_array(ax) - shape = na.shape_broadcasted(*args) + shape_args = na.shape_broadcasted(*args) + + shape = na.broadcast_shapes(ax.shape, shape_args) if axis is None: - if len(shape) != 1: + if len(shape_args) != 1: raise ValueError( - f"if `axis` is `None`, the broadcasted shape of `*args`, {shape}, should have one element" + f"if `axis` is `None`, the broadcasted shape of `*args`, " + f"{shape_args}, should have one element" ) - axis = next(iter(shape)) + axis = next(iter(shape_args)) shape_orthogonal = {a: shape[a] for a in shape if a != axis} diff --git a/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py b/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py index 9709fd1..24fdb1b 100644 --- a/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py +++ b/named_arrays/_scalars/uncertainties/uncertainties_named_array_functions.py @@ -312,14 +312,17 @@ def plt_plot_like( except na.UncertainScalarTypeError: return NotImplemented - shape = na.shape_broadcasted(*args) + shape_args = na.shape_broadcasted(*args) + + shape = na.broadcast_shapes(na.shape(ax), shape_args) if axis is None: - if len(shape) != 1: + if len(shape_args) != 1: raise ValueError( - f"if `axis` is `None`, the broadcasted shape of `*args`, {shape}, should have one element" + f"if `axis` is `None`, the broadcasted shape of `*args`, " + f"{shape_args}, should have one element" ) - axis = next(iter(shape)) + axis = next(iter(shape_args)) args = tuple(arg.broadcast_to(shape) for arg in args) diff --git a/named_arrays/tests/test_core.py b/named_arrays/tests/test_core.py index 7c915de..0f509be 100644 --- a/named_arrays/tests/test_core.py +++ b/named_arrays/tests/test_core.py @@ -1263,7 +1263,6 @@ def test_interp( argnames="ax", argvalues=[ np._NoValue, - plt.subplots()[1], na.plt.subplots(axis_rows="x", nrows=num_x)[1], ] ) @@ -1295,21 +1294,23 @@ def test_plt_plot_like( if alpha is not np._NoValue: kwargs["alpha"] = alpha - shape = na.shape_broadcasted(*args) + shape_args = na.shape_broadcasted(*args) + + shape = na.broadcast_shapes(na.shape(ax), shape_args) axis_normalized = axis if axis_normalized is np._NoValue: axis_normalized = None if axis_normalized is None: - if len(shape) != 1: + if len(shape_args) != 1: with pytest.raises( expected_exception=ValueError, match="if `axis` is `None`, the broadcasted shape of .* should have one element" ): func(*args, **kwargs) return - axis_normalized = next(iter(shape)) + axis_normalized = next(iter(shape_args)) shape_orthogonal = {a: shape[a] for a in shape if a != axis_normalized}