diff --git a/pdm.lock b/pdm.lock index 29b2f6c..d8efec6 100644 --- a/pdm.lock +++ b/pdm.lock @@ -6,7 +6,7 @@ groups = ["default", "style", "util", "docs-examples", "test", "viz", "docs"] cross_platform = true static_urls = false lock_version = "4.3" -content_hash = "sha256:661044947a135c2fd35f044894672183f425c2c57b03ca9c93abece419035faf" +content_hash = "sha256:1702a1b450c50d15a33409f28b57105b302e959c210b6ac02029f9fbfa14fed5" [[package]] name = "affine" @@ -2604,32 +2604,17 @@ files = [ [[package]] name = "pystac-client" -version = "0.7.5" -requires_python = ">=3.8" -summary = "Python library for working with SpatioTemporal Asset Catalog (STAC) APIs." -dependencies = [ - "pystac[validation]>=1.8.2", - "python-dateutil>=2.8.2", - "requests>=2.28.2", -] -files = [ - {file = "pystac-client-0.7.5.tar.gz", hash = "sha256:4b0ed0f7177dfc6e394aeb3ecf1236364f315b1d38c107afbcbbef17c2f7db8b"}, - {file = "pystac_client-0.7.5-py3-none-any.whl", hash = "sha256:b07c21f0bfbe7ea19cd23e535406ee08ee604b8ff8d9afcee666c0b1fe017dc4"}, -] - -[[package]] -name = "pystac" -version = "1.8.3" -extras = ["validation"] +version = "0.6.1" requires_python = ">=3.8" -summary = "Python library for working with the SpatioTemporal Asset Catalog (STAC) specification" +summary = "Python library for working with Spatiotemporal Asset Catalog (STAC)." dependencies = [ - "jsonschema<4.18,>=4.0.1", - "pystac==1.8.3", + "pystac>=1.7.0", + "python-dateutil>=2.7.0", + "requests>=2.27.1", ] files = [ - {file = "pystac-1.8.3-py3-none-any.whl", hash = "sha256:91805520b0b5386db84aae5296dc6d4fb6754410c481d0a00a8afedc3b4c75d5"}, - {file = "pystac-1.8.3.tar.gz", hash = "sha256:3fd0464bfeb7e99893b24c8d683dd3d046c48b2e53ed65d0a8a704f1281f1ed1"}, + {file = "pystac-client-0.6.1.tar.gz", hash = "sha256:1981537ad0fd167b08790eb3f41e7c2788438f461125b42b47bc934eaf1adcb1"}, + {file = "pystac_client-0.6.1-py3-none-any.whl", hash = "sha256:124d81bd9653b3e12c7ff244bf0dad420cadeaf86ab394dfdc804958ff723fcd"}, ] [[package]] @@ -2654,15 +2639,15 @@ files = [ [[package]] name = "python-dateutil" -version = "2.8.2" -requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +version = "2.7.5" +requires_python = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" summary = "Extensions to the standard Python datetime module" dependencies = [ "six>=1.5", ] files = [ - {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, - {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, + {file = "python-dateutil-2.7.5.tar.gz", hash = "sha256:88f9287c0174266bb0d8cedd395cfba9c58e87e5ad86b2ce58859bc11be3cf02"}, + {file = "python_dateutil-2.7.5-py2.py3-none-any.whl", hash = "sha256:063df5763652e21de43de7d9e00ccf239f953a832941e37be541614732cdfc93"}, ] [[package]] @@ -2958,6 +2943,18 @@ files = [ {file = "s3transfer-0.6.0.tar.gz", hash = "sha256:2ed07d3866f523cc561bf4a00fc5535827981b117dd7876f036b0c1aca42c947"}, ] +[[package]] +name = "sat-stac" +version = "0.4.1" +summary = "A Python library for working with Spatio-Temporal Asset Catalogs (STAC)" +dependencies = [ + "python-dateutil~=2.7.5", + "requests>=2.19.1", +] +files = [ + {file = "sat-stac-0.4.1.tar.gz", hash = "sha256:a06bb3491ee49497262d228e9c8aed4d68d89e7dff140e2439f199dd92864f2c"}, +] + [[package]] name = "scipy" version = "1.6.1" diff --git a/pyproject.toml b/pyproject.toml index c1e772c..80b8fd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ style = [ test = [ 
"hypothesis<7.0.0,>=6.35.0", "pytest<7.0.0,>=6.2.5", + "sat-stac>=0.4.1", ] util = [ "py-spy", diff --git a/stackstac/__init__.py b/stackstac/__init__.py index b179390..f9946d2 100644 --- a/stackstac/__init__.py +++ b/stackstac/__init__.py @@ -1,6 +1,6 @@ from .rio_env import LayeredEnv from .rio_reader import DEFAULT_GDAL_ENV, MULTITHREADED_DRIVER_ALLOWLIST -from .stack import stack +from .stack import stack, DEFAULT_RETRY_ERRORS, DEFAULT_ERRORS_AS_NODATA from .ops import mosaic from .geom_utils import reproject_array, array_bounds, array_epsg, xyztile_of_array @@ -13,7 +13,6 @@ msg = _traceback.format_exc() def _missing_imports(*args, **kwargs): - raise ImportError( "Optional dependencies for map visualization are missing.\n" "Please re-install stackstac with the `viz` extra:\n" @@ -34,6 +33,8 @@ def _missing_imports(*args, **kwargs): __all__ = [ "LayeredEnv", "DEFAULT_GDAL_ENV", + "DEFAULT_RETRY_ERRORS", + "DEFAULT_ERRORS_AS_NODATA", "MULTITHREADED_DRIVER_ALLOWLIST", "stack", "show", diff --git a/stackstac/nodata_reader.py b/stackstac/nodata_reader.py deleted file mode 100644 index 8aab7f1..0000000 --- a/stackstac/nodata_reader.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Tuple, Type, Union -import re - -import numpy as np -from rasterio.windows import Window - -from .reader_protocol import Reader - -State = Tuple[np.dtype, Union[int, float]] - - -class NodataReader: - "Reader that returns a constant (nodata) value for all reads" - scale_offset = (1.0, 0.0) - - def __init__( - self, - *, - dtype: np.dtype, - fill_value: Union[int, float], - **kwargs, - ) -> None: - self.dtype = dtype - self.fill_value = fill_value - - def read(self, window: Window, **kwargs) -> np.ndarray: - return nodata_for_window(window, self.fill_value, self.dtype) - - def close(self) -> None: - pass - - def __getstate__(self) -> State: - return (self.dtype, self.fill_value) - - def __setstate__(self, state: State) -> None: - self.dtype, self.fill_value = state - - -def nodata_for_window(window: Window, fill_value: Union[int, float], dtype: np.dtype): - return np.full((window.height, window.width), fill_value, dtype) - - -def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool: - """ - Whether an exception matches one of the pattern exceptions - - Parameters - ---------- - e: - The exception to check - patterns: - Instances of an Exception type to catch, where ``str(exception_pattern)`` - is a regex pattern to match against ``str(e)``. - """ - e_type = type(e) - e_msg = str(e) - for pattern in patterns: - if issubclass(e_type, type(pattern)): - if re.match(str(pattern), e_msg): - return True - return False - - -# Type assertion -_: Type[Reader] = NodataReader diff --git a/stackstac/reader_protocol.py b/stackstac/reader_protocol.py index 4b38fae..1e3854f 100644 --- a/stackstac/reader_protocol.py +++ b/stackstac/reader_protocol.py @@ -25,6 +25,7 @@ class Reader(Pickleable, Protocol): """ Protocol for a thread-safe, lazily-loaded object for reading data from a single-band STAC asset. """ + url: str def __init__( self, @@ -36,7 +37,6 @@ def __init__( fill_value: Union[int, float], scale_offset: Tuple[Union[int, float], Union[int, float]], gdal_env: Optional[LayeredEnv], - errors_as_nodata: Tuple[Exception, ...] = (), ) -> None: """ Construct the Dataset *without* fetching any data. @@ -58,14 +58,6 @@ def __init__( gdal_env: A `~.LayeredEnv` of GDAL configuration options to use while opening and reading datasets. If None (default), `~.DEFAULT_GDAL_ENV` is used. 
- errors_as_nodata: - Exception patterns to ignore when opening datasets or reading data. - Exceptions matching the pattern will be logged as warnings, and just - produce nodata (``fill_value``). - - The exception patterns should be instances of an Exception type to catch, - where ``str(exception_pattern)`` is a regex pattern to match against - ``str(raised_exception)``. """ # TODO colormaps? @@ -113,6 +105,7 @@ class FakeReader: def __init__(self, *, dtype: np.dtype, **kwargs) -> None: self.dtype = dtype + self.url = "fake" def read(self, window: Window, **kwargs) -> np.ndarray: return np.random.random((window.height, window.width)).astype(self.dtype) diff --git a/stackstac/rio_reader.py b/stackstac/rio_reader.py index 61e44a1..85a0c54 100644 --- a/stackstac/rio_reader.py +++ b/stackstac/rio_reader.py @@ -2,7 +2,6 @@ import logging import threading -import warnings from typing import TYPE_CHECKING, Optional, Protocol, Tuple, Type, TypedDict, Union import numpy as np @@ -13,7 +12,6 @@ from .timer import time from .reader_protocol import Reader from .raster_spec import RasterSpec -from .nodata_reader import NodataReader, exception_matches, nodata_for_window if TYPE_CHECKING: from rasterio.enums import Resampling @@ -42,7 +40,7 @@ def _curthread(): open=dict( GDAL_DISABLE_READDIR_ON_OPEN="EMPTY_DIR", # ^ stop GDAL from requesting `.aux` and `.msk` files from the bucket (speeds up `open` time a lot) - VSI_CACHE=True + VSI_CACHE=True, # ^ cache HTTP requests for opening datasets. This is critical for `ThreadLocalRioDataset`, # which re-opens the same URL many times---having the request cached makes subsequent `open`s # in different threads snappy. @@ -283,7 +281,6 @@ class PickleState(TypedDict): fill_value: Union[int, float] scale_offset: Tuple[Union[int, float], Union[int, float]] gdal_env: Optional[LayeredEnv] - errors_as_nodata: Tuple[Exception, ...] class AutoParallelRioReader: @@ -306,7 +303,6 @@ def __init__( fill_value: Union[int, float], scale_offset: Tuple[Union[int, float], Union[int, float]], gdal_env: Optional[LayeredEnv] = None, - errors_as_nodata: Tuple[Exception, ...] 
= (), ) -> None: self.url = url self.spec = spec @@ -315,7 +311,6 @@ def __init__( self.fill_value = fill_value self.scale_offset = scale_offset self.gdal_env = gdal_env or DEFAULT_GDAL_ENV - self.errors_as_nodata = errors_as_nodata self._dataset: Optional[ThreadsafeRioDataset] = None self._dataset_lock = threading.Lock() @@ -323,17 +318,7 @@ def __init__( def _open(self) -> ThreadsafeRioDataset: with self.gdal_env.open: with time(f"Initial read for {self.url!r} on {_curthread()}: {{t}}"): - try: - ds = SelfCleaningDatasetReader(self.url, sharing=False) - except Exception as e: - msg = f"Error opening {self.url!r}: {e!r}" - if exception_matches(e, self.errors_as_nodata): - warnings.warn(msg) - return NodataReader( - dtype=self.dtype, fill_value=self.fill_value - ) - - raise RuntimeError(msg) from e + ds = SelfCleaningDatasetReader(self.url, sharing=False) if ds.count != 1: ds.close() raise RuntimeError( @@ -375,7 +360,7 @@ def _open(self) -> ThreadsafeRioDataset: return SingleThreadedRioDataset(self.gdal_env, ds, vrt=vrt) @property - def dataset(self): + def dataset(self) -> ThreadsafeRioDataset: with self._dataset_lock: if self._dataset is None: self._dataset = self._open() @@ -383,22 +368,14 @@ def dataset(self): def read(self, window: Window, **kwargs) -> np.ndarray: reader = self.dataset - try: - result = reader.read( - window=window, - out_dtype=self.dtype, - masked=True, - # ^ NOTE: we always do a masked array, so we can safely apply scales and offsets - # without potentially altering pixels that should have been the ``fill_value`` - **kwargs, - ) - except Exception as e: - msg = f"Error reading {window} from {self.url!r}: {e!r}" - if exception_matches(e, self.errors_as_nodata): - warnings.warn(msg) - return nodata_for_window(window, self.fill_value, self.dtype) - - raise RuntimeError(msg) from e + result = reader.read( + window=window, + out_dtype=self.dtype, + masked=True, + # ^ NOTE: we always do a masked array, so we can safely apply scales and offsets + # without potentially altering pixels that should have been the ``fill_value`` + **kwargs, + ) # When the GeoTIFF doesn't have a nodata value, and we're using a VRT, pixels # outside the dataset don't get properly masked (they're just 0). Using `add_alpha` @@ -409,7 +386,9 @@ def read(self, window: Window, **kwargs) -> np.ndarray: elif result.shape[0] == 1: result = result[0] else: - raise RuntimeError(f"Unexpected shape {result.shape}, expected exactly 1 band.") + raise RuntimeError( + f"Unexpected shape {result.shape}, expected exactly 1 band." 
+ ) scale, offset = self.scale_offset @@ -419,9 +398,9 @@ def read(self, window: Window, **kwargs) -> np.ndarray: result += offset result = np.ma.filled(result, fill_value=self.fill_value) - assert np.issubdtype(result.dtype, self.dtype), ( - f"Expected result array with dtype {self.dtype!r}, got {result.dtype!r}" - ) + assert np.issubdtype( + result.dtype, self.dtype + ), f"Expected result array with dtype {self.dtype!r}, got {result.dtype!r}" return result def close(self) -> None: @@ -451,7 +430,6 @@ def __getstate__( "fill_value": self.fill_value, "scale_offset": self.scale_offset, "gdal_env": self.gdal_env, - "errors_as_nodata": self.errors_as_nodata, } def __setstate__( diff --git a/stackstac/stack.py b/stackstac/stack.py index dc85fd9..b65d5c3 100644 --- a/stackstac/stack.py +++ b/stackstac/stack.py @@ -23,6 +23,15 @@ from .to_dask import items_to_dask, ChunksParam +DEFAULT_RETRY_ERRORS = ( + RasterioIOError(r"HTTP response code: (400|429|5\d\d)"), + RasterioIOError("Read or write failed"), + RasterioIOError("not recognized as a supported file format"), +) + +DEFAULT_ERRORS_AS_NODATA = (RasterioIOError("HTTP response code: 404"),) + + def stack( items: Union[ ItemCollectionIsh, ItemIsh, Sequence[PystacItem], Sequence[SatstacItem] @@ -45,9 +54,9 @@ def stack( properties: Union[bool, str, Sequence[str]] = True, band_coords: bool = True, gdal_env: Optional[LayeredEnv] = None, - errors_as_nodata: Tuple[Exception, ...] = ( - RasterioIOError("HTTP response code: 404"), - ), + retry_errors: Tuple[Exception, ...] = DEFAULT_RETRY_ERRORS, + retries: int = 3, + errors_as_nodata: Tuple[Exception, ...] = DEFAULT_ERRORS_AS_NODATA, reader: Type[Reader] = AutoParallelRioReader, ) -> xr.DataArray: """ @@ -247,16 +256,30 @@ def stack( Advanced use: a `~.LayeredEnv` of GDAL configuration options to use while opening and reading datasets. If None (default), `~.DEFAULT_GDAL_ENV` is used. See ``rio_reader.py`` for notes on why these default options were chosen. + retry_errors: + Exception patterns to retry when opening datasets or reading data. + Errors matching one of the patterns will be logged as warnings, and the read will be retried. + + The exception patterns should be instances of an `Exception` type to catch, + where ``str(exception_pattern)`` is a regex pattern to match against + ``str(raised_exception)``. For example, ``RasterioIOError("Read or write failed")`` + (one of the defaults). Or ``IOError(r"HTTP response code: 5\\d\\d")``, to retry any 5xx HTTP error. + Or ``Exception(".*")`` to retry absolutely anything (that one's probably a bad idea). + + By default, HTTP 400, 429, and 5xx errors are retried, as well as ``Read or write failed`` and + ``not recognized as a supported file format`` errors (see `DEFAULT_RETRY_ERRORS`). + retries: + How many times to retry errors in `retry_errors`. Default: 3. Set to 0 to disable retries. + + Note that retries are counted per read of each asset within a chunk, not once per asset. errors_as_nodata: Exception patterns to ignore when opening datasets or reading data. Exceptions matching the pattern will be logged as warnings, and just produce nodata (``fill_value``). - The exception patterns should be instances of an Exception type to catch, - where ``str(exception_pattern)`` is a regex pattern to match against - ``str(raised_exception)``. For example, ``RasterioIOError("HTTP response code: 404")`` - (the default). Or ``IOError(r"HTTP response code: 4\\d\\d")``, to catch any 4xx HTTP error. - Or ``Exception(".*")`` to catch absolutely anything (that one's probably a bad idea). + The exception patterns work the same way as in `retry_errors`. 
+ + By default, only HTTP 404 errors are ignored as nodata. reader: Advanced use: the `~.Reader` type to use. Currently there is only one real reader type: `~.AutoParallelRioReader`. However, there's also `~.FakeReader` (which doesn't read data at all, @@ -309,6 +332,8 @@ def stack( reader=reader, gdal_env=gdal_env, errors_as_nodata=errors_as_nodata, + retry_errors=retry_errors, + retries=retries, ) return xr.DataArray( diff --git a/stackstac/tests/test_to_dask.py b/stackstac/tests/test_to_dask.py index b814961..74f2bfc 100644 --- a/stackstac/tests/test_to_dask.py +++ b/stackstac/tests/test_to_dask.py @@ -1,10 +1,12 @@ from __future__ import annotations +import contextlib from threading import Lock -from typing import ClassVar +from typing import ClassVar, Literal from hypothesis import given, settings, strategies as st import hypothesis.extra.numpy as st_np import numpy as np +import pytest from rasterio import windows import dask.core import dask.threaded @@ -16,6 +18,7 @@ ChunksParam, items_to_dask, normalize_chunks, + read_handle_errors, ) from stackstac.testing import strategies as st_stc @@ -36,9 +39,9 @@ def asset_tables( np.array( [ # Encode the (i, j) index in the table in the URL - [("fake://0/0", [0, 0, 2, 2]), ("fake://0/1", [0, 0, 2, 2])], - [("fake://1/0", [0, 3, 2, 5]), ("fake://1/1", [10, 13, 12, 15])], - [("fake://2/0", [1, 3, 2, 6]), ("fake://2/1", [1, 3, 2, 7])], + [("fake://0/0", [0, 0, 2, 2], (1, 0)), ("fake://0/1", [0, 0, 2, 2], (1, 0))], + [("fake://1/0", [0, 3, 2, 5], (1, 0)), ("fake://1/1", [10, 13, 12, 15], (1, 0))], + [("fake://2/0", [1, 3, 2, 6], (1, 0)), ("fake://2/1", [1, 3, 2, 7], (1, 0))], [(None, None), (None, None)], ], dtype=ASSET_TABLE_DT, @@ -61,7 +64,8 @@ def asset_tables( for (i, j), bounds in np.ndenumerate(bounds_arr): if bounds: # Encode the (i, j) index in the table in the URL - asset_table[i, j] = (f"fake://{i}/{j}", bounds) + # TODO generate scale and offset + asset_table[i, j] = (f"fake://{i}/{j}", bounds, (1, 0)) return asset_table @@ -138,6 +142,7 @@ def __init__( i, j = map(int, url[7:].split("/")) self.full_data = results[i, j] self.window = asset_windows[url] + self.url = url assert spec == spec_ assert dtype == dtype_ @@ -200,3 +205,170 @@ def test_normalize_chunks( assert all(x >= 1 for t in chunks for x in t) if isinstance(chunksize, int) or isinstance(chunks, tuple) and len(chunks) == 2: assert numblocks[:2] == shape[:2] + + +def test_error_handling(): + asset_table = np.array( + [ + [("00", (0, 0, 1, 1), (1, 0)), ("01", (0, 0, 1, 1), (1, 0))], + [("10", (0, 0, 1, 1), (1, 0)), ("11", (0, 0, 1, 1), (1, 0))], + [("20", (0, 0, 1, 1), (1, 0)), ("21", (0, 0, 1, 1), (1, 0))], + ], + dtype=ASSET_TABLE_DT, + ) + + spec_ = RasterSpec(4326, (0, 0, 1, 1), (0.5, 0.5)) + # Errors that will occur for each asset. (Testing this per asset + # is a lot easier than per chunk, but should generalize the same.) 
+ errors = np.full(asset_table.shape, fill_value=None, dtype=object) + errors[0, 0] = [RuntimeError("retry")] + errors[0, 1] = [RuntimeError("retry me"), RuntimeError("nodata")] + errors[1, 1] = [RuntimeError("nodata")] + errors[2, 0] = [ + ValueError("HTTP error: 503"), + ValueError("HTTP error: 500"), + RuntimeError("retry me"), + ] + + class ErrorReader: + def __init__( + self, + *, + url: str, + # spec: RasterSpec, + # dtype: np.dtype, + # fill_value: int | float, + **kwargs, + ) -> None: + self.url = url + i, j = int(url[0]), int(url[1]) + self.errors = errors[i, j] + self.reads = 0 + + def read(self, window: windows.Window) -> np.ndarray: + self.reads += 1 + if self.errors and self.reads <= len(self.errors): + raise self.errors[self.reads - 1] + else: + # NOTE: we use a chunksize of -1 for simplicity, so every asset is a full chunk. + # This also means we don't have to lock `self.reads`, because there will only ever + # be one dask task per ErrorReader. + return np.full(spec_.shape, fill_value=1) + + def close(self) -> None: + pass + + def __getstate__(self) -> dict: + return self.__dict__ + + def __setstate__(self, state): + self.__init__(**state) + + kwargs = dict( + asset_table=asset_table, + spec=spec_, + chunksize=-1, + dtype=np.dtype(int), + fill_value=0, + reader=ErrorReader, + errors_as_nodata=(RuntimeError("nodata"),), + retry_errors=(RuntimeError("retry"), ValueError(r"HTTP error: 5\d\d")), + ) + + arr = items_to_dask( + **kwargs, # type: ignore + retries=0, + ) + with pytest.raises(RuntimeError, match="retry"): + arr.compute(scheduler="sync") + + arr = items_to_dask( + **kwargs, # type: ignore + retries=1, + ) + with pytest.warns(UserWarning, match="Error reading") as record: + with pytest.raises(RuntimeError, match="HTTP error: 500"): + arr.compute(scheduler="sync") + assert len(record) == 5 + + arr = items_to_dask( + **kwargs, # type: ignore + retries=3, + ) + + expected = np.full(arr.shape, fill_value=1, dtype=arr.dtype) + # where we have nodata errors + expected[0, 1] = 0 + expected[1, 1] = 0 + + with pytest.warns(UserWarning, match="Error reading") as record: + assert_eq(arr, expected) + assert len(record) == 7 + + +@pytest.mark.parametrize( + "errors, retries, attempts, expected", + [ + ([], 0, 1, "data"), + ([], 1, 1, "data"), + ([RuntimeError("retry")], 1, 2, "data"), + ([RuntimeError("nodata")], 0, 1, "nodata"), + ([RuntimeError("foo"), RuntimeError("bar")], 0, 1, RuntimeError("foo")), + ([RuntimeError("foo"), RuntimeError("bar")], 1, 1, RuntimeError("foo")), + ([RuntimeError("retry"), RuntimeError("retry")], 1, 2, RuntimeError("retry")), + ([RuntimeError("retry"), RuntimeError("retry")], 2, 3, "data"), + ([RuntimeError("retry"), RuntimeError("nodata")], 2, 2, "nodata"), + ([RuntimeError("nodata"), RuntimeError("retry")], 2, 1, "nodata"), + ], +) +def test_read_handle_errors( + errors: tuple[Exception, ...], + retries: int, + attempts: int, + expected: Exception | Literal["nodata"] | Literal["data"], + recwarn, +): + class ErrorReader: + def __init__(self, errors=(), **kwargs) -> None: + self.url = "test" + self.errors = list(errors) + self.reads = 0 + + def read(self, window: windows.Window) -> np.ndarray: + self.reads += 1 + if self.errors: + raise self.errors.pop(0) + return np.array([1]) + + def close(self) -> None: + pass + + def __getstate__(self) -> dict: + return self.__dict__ + + def __setstate__(self, state): + self.__init__(**state) + + window = windows.Window(0, 0, 1, 1) # type: ignore + reader = ErrorReader(errors=errors) + + ctx = ( + 
pytest.raises(type(expected), match=str(expected)) + if isinstance(expected, Exception) + else contextlib.nullcontext() + ) + with ctx: + result = read_handle_errors( + reader, + window, + retry_errors=(RuntimeError("retry"),), + errors_as_nodata=(RuntimeError("nodata"),), + retries=retries, + ) + if expected == "nodata": + assert result is None + elif expected == "data": + assert isinstance(result, np.ndarray) + + assert reader.reads == attempts + assert len(recwarn) == (attempts if expected == "nodata" else attempts - 1) diff --git a/stackstac/to_dask.py b/stackstac/to_dask.py index 280196b..cf30a59 100644 --- a/stackstac/to_dask.py +++ b/stackstac/to_dask.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Dict, Literal, Optional, Tuple, Type, Union +from typing import Dict, Literal, NamedTuple, Optional, Tuple, Type, Union +import re import warnings -from affine import Affine import dask import dask.array as da from dask.blockwise import blockwise @@ -32,9 +32,12 @@ def items_to_dask( reader: Type[Reader] = AutoParallelRioReader, gdal_env: Optional[LayeredEnv] = None, errors_as_nodata: Tuple[Exception, ...] = (), + retry_errors: Tuple[Exception, ...] = (), + retries: int = 0, ) -> da.Array: "Create a dask Array from an asset table" errors_as_nodata = errors_as_nodata or () # be sure it's not None + retry_errors = retry_errors or () # be sure it's not None if not np.can_cast(fill_value, dtype): raise ValueError( @@ -68,7 +71,7 @@ def items_to_dask( # don't match, which may not be behavior we can rely on. # (The actual content of the annotation is irrelevant here, just that there is one.) reader_table = asset_table_dask.map_blocks( - asset_table_to_reader_and_window, + asset_table_to_reader_entry, spec, resampling, dtype, @@ -76,6 +79,8 @@ def items_to_dask( rescale, gdal_env, errors_as_nodata, + retry_errors, + retries, reader, dtype=object, ) @@ -105,10 +110,15 @@ def items_to_dask( return rasters -ReaderTableEntry = Optional[Tuple[Reader, windows.Window]] +class ReaderTableEntry(NamedTuple): + reader: Reader + asset_window: windows.Window + errors_as_nodata: Tuple[Exception, ...] + retry_errors: Tuple[Exception, ...] + retries: int -def asset_table_to_reader_and_window( +def asset_table_to_reader_entry( asset_table: np.ndarray, spec: RasterSpec, resampling: Resampling, @@ -117,14 +127,16 @@ def asset_table_to_reader_and_window( rescale: bool, gdal_env: Optional[LayeredEnv], errors_as_nodata: Tuple[Exception, ...], + retry_errors: Tuple[Exception, ...], + retries: int, reader: Type[Reader], ) -> np.ndarray: """ "Open" an asset table by creating a `Reader` for each asset. This function converts the asset table (or chunks thereof) into an object array, - where each element contains a tuple of the `Reader` and `Window` for that asset, - or None if the element has no URL. + where each element contains the `ReaderTableEntry` for that asset, or None if the + element has no URL. 
""" reader_table = np.empty_like(asset_table, dtype=object) for index, asset_entry in np.ndenumerate(asset_table): @@ -137,18 +149,26 @@ def asset_table_to_reader_and_window( else: asset_scale_offset = (1, 0) - entry: ReaderTableEntry = ( - reader( - url=url, - spec=spec, - resampling=resampling, - dtype=dtype, - fill_value=fill_value, - scale_offset=asset_scale_offset, - gdal_env=gdal_env, - errors_as_nodata=errors_as_nodata, - ), - asset_window, + r = reader( + url=url, + spec=spec, + resampling=resampling, + dtype=dtype, + fill_value=fill_value, + scale_offset=asset_scale_offset, + gdal_env=gdal_env, + ) + + # NOTE: to minimize dask graph size, we put things that would be better + # suited as arguments for `fetch_raster_window` (like `retry_errors`) into + # the reader table instead. This saves pickling the same tuples of exceptions + # for every chunk (pickle isn't smart enough to avoid the duplication). + entry = ReaderTableEntry( + reader=r, + asset_window=asset_window, + errors_as_nodata=errors_as_nodata, + retry_errors=retry_errors, + retries=retries, ) reader_table[index] = entry return reader_table @@ -175,18 +195,26 @@ def fetch_raster_window( ) all_empty: bool = True - entry: ReaderTableEntry + entry: ReaderTableEntry | None for index, entry in np.ndenumerate(reader_table): if entry: - reader, asset_window = entry # Only read if the window we're fetching actually overlaps with the asset - if windows.intersect(current_window, asset_window): + if windows.intersect(current_window, entry.asset_window): # NOTE: when there are multiple assets, we _could_ parallelize these reads with our own threadpool. # However, that would probably increase memory usage, since the internal, thread-local GDAL datasets # would end up copied to even more threads. # TODO when the Reader won't be rescaling, support passing `output` to avoid the copy? - data = reader.read(current_window) + data = read_handle_errors( + entry.reader, + current_window, + entry.retry_errors, + entry.errors_as_nodata, + entry.retries, + ) + if data is None: + # A nodata error; just leave this empty in `output` + continue if all_empty: # Turn `output` from a broadcast-trick array to a real array, so it's writeable @@ -205,6 +233,59 @@ def fetch_raster_window( return output +def read_handle_errors( + reader: Reader, + window: windows.Window, + retry_errors: Tuple[Exception, ...], + errors_as_nodata: Tuple[Exception, ...], + retries: int, +) -> np.ndarray | None: + attempt = 0 + while True: + try: + return reader.read(window) + except Exception as e: + msg = f"Error reading {window} from {reader.url!r} (attempt {attempt} of {retries}): {e!r}" + if attempt < retries and exception_matches(e, retry_errors): + warnings.warn(msg) + attempt += 1 + # TODO: sleep + # TODO: is sleeping the best way to do this? in principle, we could let some other + # request happen while we're waiting (maybe something else from the reader table, + # or an entirely different dask task). secede/rejoin might be an option? but might + # also encourage higher memory use and make things more complicated. + # Also, whatever problem is happening may or may not be independent between assets... + # could be some internals of cloud storage that are correlated. 
+ continue + + if exception_matches(e, errors_as_nodata): + warnings.warn("Ignoring as nodata: " + msg) + return None + + raise RuntimeError(msg) from e + + +def exception_matches(e: Exception, patterns: Tuple[Exception, ...]) -> bool: + """ + Whether an exception matches one of the pattern exceptions + + Parameters + ---------- + e: + The exception to check + patterns: + Instances of an Exception type to catch, where ``str(exception_pattern)`` + is a regex pattern to match against ``str(e)``. + """ + e_type = type(e) + e_msg = str(e) + for pattern in patterns: + if issubclass(e_type, type(pattern)): + if re.match(str(pattern), e_msg): + return True + return False + + def normalize_chunks( chunks: ChunksParam, shape: Tuple[int, int, int, int], dtype: np.dtype ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], Tuple[int, ...]]:
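Illustrative usage sketch (not part of the patch itself): how the retry parameters this diff adds to `stackstac.stack` might be used. `my_items` is a stand-in for whatever STAC items you already have; `stackstac.DEFAULT_RETRY_ERRORS` and `stackstac.DEFAULT_ERRORS_AS_NODATA` are the tuples exported from `stackstac/__init__.py` above, and the HTTP 408 pattern is just a hypothetical extension of the defaults.

import stackstac
from rasterio.errors import RasterioIOError

# `my_items` is assumed to be an existing STAC ItemCollection (or list of items).
arr = stackstac.stack(
    my_items,
    retries=5,  # retry each failing read up to 5 times (default 3; 0 disables retries)
    # Extend the default retry patterns (HTTP 400/429/5xx, "Read or write failed", ...)
    # to also retry HTTP 408 timeouts:
    retry_errors=stackstac.DEFAULT_RETRY_ERRORS
    + (RasterioIOError("HTTP response code: 408"),),
    # Keep treating HTTP 404s as nodata (already the default):
    errors_as_nodata=stackstac.DEFAULT_ERRORS_AS_NODATA,
)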