diff --git a/pysheds/_sgrid.py b/pysheds/_sgrid.py index e0db269..f46865a 100644 --- a/pysheds/_sgrid.py +++ b/pysheds/_sgrid.py @@ -1,7 +1,8 @@ from heapq import heappop, heappush, heapify import math import numpy as np -from numba import njit, prange +from functools import wraps +from numba import njit, prange, from_dtype from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void from numba.typed import typedlist @@ -1919,11 +1920,29 @@ def count(start=0, step=1): yield n n += step + +def pfwrapper(func): + # Implemenation detail of priority-flood algorithm + # Needed to define the types used in priority queue + @wraps(func) + def _wrapper(dem, mask, *args): + # Tuple elements: + # 0: dem data type (for elevation priority) + # 1: int64 for insertion index (to maintain total ordering) + # 2: int64 for row index + # 3: int64 for col index + tuple_type = Tuple([from_dtype(dem.dtype), int64, int64, int64]) + return func(dem, mask, tuple_type, *args) + return _wrapper + + +@pfwrapper @njit(cache=True) def _priority_flood(dem, dem_mask, tuple_type): open_cells = typedlist.List.empty_list(tuple_type) # Priority queue pits = typedlist.List.empty_list(tuple_type) # FIFO queue closed_cells = dem_mask.copy() + isertn = count() # Push the edges onto priority queue y, x = dem.shape @@ -1931,22 +1950,22 @@ def _priority_flood(dem, dem_mask, tuple_type): edge = _left(dem_mask)[:-1] for row, col in zip(count(), edge): if col >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True edge = _bottom(dem_mask)[:-1] for row, col in zip(edge, count()): if row >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True edge = np.flip(_right(dem_mask))[:-1] for row, col in zip(count(y - 1, step=-1), edge): if col >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True edge = np.flip(_top(dem_mask))[:-1] for row, col in zip(edge, count(x - 1, step=-1)): if row >= 0: - open_cells.append((dem[row, col], row, col)) + open_cells.append((dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True heapify(open_cells) @@ -1956,10 +1975,10 @@ def _priority_flood(dem, dem_mask, tuple_type): pits_pos = 0 while open_cells or pits_pos < len(pits): if pits_pos < len(pits): - elv, i, j = pits[pits_pos] + elv, _, i, j = pits[pits_pos] pits_pos += 1 else: - elv, i, j = heappop(open_cells) + elv, _, i, j = heappop(open_cells) for n in range(8): row = i + row_offsets[n] @@ -1973,9 +1992,9 @@ def _priority_flood(dem, dem_mask, tuple_type): if dem[row, col] <= elv: dem[row, col] = elv - pits.append((elv, row, col)) + pits.append((elv, next(isertn), row, col)) else: - heappush(open_cells, (dem[row, col], row, col)) + heappush(open_cells, (dem[row, col], next(isertn), row, col)) closed_cells[row, col] = True # pits book-keeping