Skip to content

Commit

Permalink
Add wrapper function.
Browse files Browse the repository at this point in the history
The decorator hides an implementation detail of numba that the caller need not worry about.
  • Loading branch information
groutr committed Feb 14, 2024
1 parent 91176eb commit 04b061b
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions pysheds/_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from heapq import heappop, heappush, heapify
import math
import numpy as np
from numba import njit, prange
from functools import wraps
from numba import njit, prange, from_dtype
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void
from numba.typed import typedlist

Expand Down Expand Up @@ -1919,34 +1920,52 @@ def count(start=0, step=1):
yield n
n += step


def pfwrapper(func):
# Implemenation detail of priority-flood algorithm
# Needed to define the types used in priority queue
@wraps(func)
def _wrapper(dem, mask, *args):
# Tuple elements:
# 0: dem data type (for elevation priority)
# 1: int64 for insertion index (to maintain total ordering)
# 2: int64 for row index
# 3: int64 for col index
tuple_type = Tuple([from_dtype(dem.dtype), int64, int64, int64])
return func(dem, mask, tuple_type, *args)
return _wrapper


@pfwrapper
@njit(cache=True)
def _priority_flood(dem, dem_mask, tuple_type):
open_cells = typedlist.List.empty_list(tuple_type) # Priority queue
pits = typedlist.List.empty_list(tuple_type) # FIFO queue
closed_cells = dem_mask.copy()
isertn = count()

# Push the edges onto priority queue
y, x = dem.shape

edge = _left(dem_mask)[:-1]
for row, col in zip(count(), edge):
if col >= 0:
open_cells.append((dem[row, col], row, col))
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = _bottom(dem_mask)[:-1]
for row, col in zip(edge, count()):
if row >= 0:
open_cells.append((dem[row, col], row, col))
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_right(dem_mask))[:-1]
for row, col in zip(count(y - 1, step=-1), edge):
if col >= 0:
open_cells.append((dem[row, col], row, col))
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_top(dem_mask))[:-1]
for row, col in zip(edge, count(x - 1, step=-1)):
if row >= 0:
open_cells.append((dem[row, col], row, col))
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
heapify(open_cells)

Expand All @@ -1956,10 +1975,10 @@ def _priority_flood(dem, dem_mask, tuple_type):
pits_pos = 0
while open_cells or pits_pos < len(pits):
if pits_pos < len(pits):
elv, i, j = pits[pits_pos]
elv, _, i, j = pits[pits_pos]
pits_pos += 1
else:
elv, i, j = heappop(open_cells)
elv, _, i, j = heappop(open_cells)

for n in range(8):
row = i + row_offsets[n]
Expand All @@ -1973,9 +1992,9 @@ def _priority_flood(dem, dem_mask, tuple_type):

if dem[row, col] <= elv:
dem[row, col] = elv
pits.append((elv, row, col))
pits.append((elv, next(isertn), row, col))
else:
heappush(open_cells, (dem[row, col], row, col))
heappush(open_cells, (dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True

# pits book-keeping
Expand Down

0 comments on commit 04b061b

Please sign in to comment.