Skip to content

Commit

Permalink
Add typing to test_plot.py (#8889)
Browse files Browse the repository at this point in the history
* Update pyproject.toml

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* Update test_plot.py

* raise ValueError if too many dims are requested

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Apr 5, 2024
1 parent 56182f7 commit 5bcbf70
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 33 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ module = [
"xarray.tests.test_merge",
"xarray.tests.test_missing",
"xarray.tests.test_parallelcompat",
"xarray.tests.test_plot",
"xarray.tests.test_sparse",
"xarray.tests.test_ufuncs",
"xarray.tests.test_units",
Expand Down
72 changes: 40 additions & 32 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import inspect
import math
from collections.abc import Hashable
from collections.abc import Generator, Hashable
from copy import copy
from datetime import date, datetime, timedelta
from typing import Any, Callable, Literal
Expand Down Expand Up @@ -85,52 +85,54 @@ def test_all_figures_closed():

@pytest.mark.flaky
@pytest.mark.skip(reason="maybe flaky")
def text_in_fig():
def text_in_fig() -> set[str]:
"""
Return the set of all text in the figure
"""
return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)}
return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?


def find_possible_colorbars():
def find_possible_colorbars() -> list[mpl.collections.QuadMesh]:
# nb. this function also matches meshes from pcolormesh
return plt.gcf().findobj(mpl.collections.QuadMesh)
return plt.gcf().findobj(mpl.collections.QuadMesh) # type: ignore[return-value] # mpl error?


def substring_in_axes(substring, ax):
def substring_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
"""
Return True if a substring is found anywhere in an axes
"""
alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)}
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
for txt in alltxt:
if substring in txt:
return True
return False


def substring_not_in_axes(substring, ax):
def substring_not_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
"""
Return True if a substring is not found anywhere in an axes
"""
alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)}
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
check = [(substring not in txt) for txt in alltxt]
return all(check)


def property_in_axes_text(property, property_str, target_txt, ax):
def property_in_axes_text(
property, property_str, target_txt, ax: mpl.axes.Axes
) -> bool:
"""
Return True if the specified text in an axes
has the property assigned to property_str
"""
alltxt = ax.findobj(mpl.text.Text)
alltxt: list[mpl.text.Text] = ax.findobj(mpl.text.Text) # type: ignore[assignment]
check = []
for t in alltxt:
if t.get_text() == target_txt:
check.append(plt.getp(t, property) == property_str)
return all(check)


def easy_array(shape, start=0, stop=1):
def easy_array(shape: tuple[int, ...], start: float = 0, stop: float = 1) -> np.ndarray:
"""
Make an array with desired shape using np.linspace
Expand All @@ -140,7 +142,7 @@ def easy_array(shape, start=0, stop=1):
return a.reshape(shape)


def get_colorbar_label(colorbar):
def get_colorbar_label(colorbar) -> str:
if colorbar.orientation == "vertical":
return colorbar.ax.get_ylabel()
else:
Expand All @@ -150,27 +152,27 @@ def get_colorbar_label(colorbar):
@requires_matplotlib
class PlotTestCase:
@pytest.fixture(autouse=True)
def setup(self):
def setup(self) -> Generator:
yield
# Remove all matplotlib figures
plt.close("all")

def pass_in_axis(self, plotmethod, subplot_kw=None):
def pass_in_axis(self, plotmethod, subplot_kw=None) -> None:
fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw)
plotmethod(ax=axs[0])
assert axs[0].has_data()

@pytest.mark.slow
def imshow_called(self, plotmethod):
def imshow_called(self, plotmethod) -> bool:
plotmethod()
images = plt.gca().findobj(mpl.image.AxesImage)
return len(images) > 0

def contourf_called(self, plotmethod):
def contourf_called(self, plotmethod) -> bool:
plotmethod()

# Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8
def matchfunc(x):
def matchfunc(x) -> bool:
return isinstance(
x, (mpl.collections.PathCollection, mpl.contour.QuadContourSet)
)
Expand Down Expand Up @@ -1248,14 +1250,16 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None:
def test_discrete_colormap_provided_boundary_norm(self) -> None:
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
primitive = self.darray.plot.contourf(norm=norm)
np.testing.assert_allclose(primitive.levels, norm.boundaries)
np.testing.assert_allclose(list(primitive.levels), norm.boundaries)

def test_discrete_colormap_provided_boundary_norm_matching_cmap_levels(
self,
) -> None:
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
primitive = self.darray.plot.contourf(norm=norm)
assert primitive.colorbar.norm.Ncmap == primitive.colorbar.norm.N
cbar = primitive.colorbar
assert cbar is not None
assert cbar.norm.Ncmap == cbar.norm.N # type: ignore[attr-defined] # Exists, debatable if public though.


class Common2dMixin:
Expand Down Expand Up @@ -2532,7 +2536,7 @@ def test_default_labels(self) -> None:

# Leftmost column should have array name
for ax in g.axs[:, 0]:
assert substring_in_axes(self.darray.name, ax)
assert substring_in_axes(str(self.darray.name), ax)

def test_test_empty_cell(self) -> None:
g = (
Expand Down Expand Up @@ -2635,7 +2639,7 @@ def test_facetgrid(self) -> None:
(True, "continuous", False, True),
],
)
def test_add_guide(self, add_guide, hue_style, legend, colorbar):
def test_add_guide(self, add_guide, hue_style, legend, colorbar) -> None:
meta_data = _infer_meta_data(
self.ds,
x="x",
Expand Down Expand Up @@ -2811,7 +2815,7 @@ def test_bad_args(
add_legend: bool | None,
add_colorbar: bool | None,
error_type: type[Exception],
):
) -> None:
with pytest.raises(error_type):
self.ds.plot.scatter(
x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar
Expand Down Expand Up @@ -3011,20 +3015,22 @@ def test_ncaxis_notinstalled_line_plot(self) -> None:
@requires_matplotlib
class TestAxesKwargs:
@pytest.fixture(params=[1, 2, 3])
def data_array(self, request):
def data_array(self, request) -> DataArray:
"""
Return a simple DataArray
"""
dims = request.param
if dims == 1:
return DataArray(easy_array((10,)))
if dims == 2:
elif dims == 2:
return DataArray(easy_array((10, 3)))
if dims == 3:
elif dims == 3:
return DataArray(easy_array((10, 3, 2)))
else:
raise ValueError(f"No DataArray implemented for {dims=}.")

@pytest.fixture(params=[1, 2])
def data_array_logspaced(self, request):
def data_array_logspaced(self, request) -> DataArray:
"""
Return a simple DataArray with logspaced coordinates
"""
Expand All @@ -3033,12 +3039,14 @@ def data_array_logspaced(self, request):
return DataArray(
np.arange(7), dims=("x",), coords={"x": np.logspace(-3, 3, 7)}
)
if dims == 2:
elif dims == 2:
return DataArray(
np.arange(16).reshape(4, 4),
dims=("y", "x"),
coords={"x": np.logspace(-1, 2, 4), "y": np.logspace(-5, -1, 4)},
)
else:
raise ValueError(f"No DataArray implemented for {dims=}.")

@pytest.mark.parametrize("xincrease", [True, False])
def test_xincrease_kwarg(self, data_array, xincrease) -> None:
Expand Down Expand Up @@ -3146,16 +3154,16 @@ def test_facetgrid_single_contour() -> None:


@requires_matplotlib
def test_get_axis_raises():
def test_get_axis_raises() -> None:
# test get_axis raises an error if trying to do invalid things

# cannot provide both ax and figsize
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something")
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something") # type: ignore[arg-type]

# cannot provide both ax and size
with pytest.raises(ValueError, match="both `size` and `ax`"):
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something")
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something") # type: ignore[arg-type]

# cannot provide both size and figsize
with pytest.raises(ValueError, match="both `figsize` and `size`"):
Expand All @@ -3167,7 +3175,7 @@ def test_get_axis_raises():

# cannot provide axis and subplot_kws
with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"):
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5)
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5) # type: ignore[arg-type]


@requires_matplotlib
Expand Down

0 comments on commit 5bcbf70

Please sign in to comment.