Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Priority Flood algorithm to fill depressions #243

Merged
merged 4 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions pysheds/_sgrid.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from heapq import heappop, heappush
from heapq import heappop, heappush, heapify
import math
import numpy as np
from numba import njit, prange
from functools import wraps
from numba import njit, prange, from_dtype
from numba.types import float64, int64, uint32, uint16, uint8, boolean, UniTuple, Tuple, List, DictType, void
from numba.typed import typedlist

# Functions for 'flowdir'

Expand Down Expand Up @@ -1856,3 +1858,149 @@ def _fill_pits_numba(dem, pit_indices):
adjustment = min(diff, adjustment)
pits_filled.flat[k] += (adjustment)
return pits_filled

@njit(boundscheck=True, cache=True)
def _first_true1d(arr, start=0, end=None, step=1, invert=False):
if end is None:
end = len(arr)

if invert:
for i in range(start, end, step):
if not arr[i]:
return i
else:
return -1
else:
for i in range(start, end, step):
if arr[i]:
return i
else:
return -1

@njit(parallel=True, cache=True)
def _top(mask):
nc = mask.shape[1]
rv = np.zeros(nc, dtype='int64')
for i in prange(nc):
rv[i] = _first_true1d(mask[:, i], invert=True)
return rv

@njit(parallel=True, cache=True)
def _bottom(mask):
nr, nc = mask.shape[0], mask.shape[1]
rv = np.zeros(nc, dtype='int64')
for i in prange(nc):
rv[i] = _first_true1d(mask[:, i], start=nr - 1, end=-1, step=-1, invert=True)
return rv

@njit(parallel=True, cache=True)
def _left(mask):
nr = mask.shape[0]
rv = np.zeros(nr, dtype='int64')
for i in prange(nr):
rv[i] = _first_true1d(mask[i, :], invert=True)
return rv

@njit(parallel=True, cache=True)
def _right(mask):
nr, nc = mask.shape[0], mask.shape[1]
rv = np.zeros(nr, dtype='int64')
for i in prange(nr):
rv[i] = _first_true1d(mask[i, :], start=nc - 1, end=-1, step=-1, invert=True)
return rv


@njit(cache=True)
def count(start=0, step=1):
# Numba accelerated count() from itertools
# count(10) --> 10 11 12 13 14 ...
# count(2.5, 0.5) --> 2.5 3.0 3.5 ...
n = start
while True:
yield n
n += step


def pfwrapper(func):
# Implemenation detail of priority-flood algorithm
# Needed to define the types used in priority queue
@wraps(func)
def _wrapper(dem, mask, *args):
# Tuple elements:
# 0: dem data type (for elevation priority)
# 1: int64 for insertion index (to maintain total ordering)
# 2: int64 for row index
# 3: int64 for col index
tuple_type = Tuple([from_dtype(dem.dtype), int64, int64, int64])
return func(dem, mask, tuple_type, *args)
return _wrapper


@pfwrapper
@njit(cache=True)
def _priority_flood(dem, dem_mask, tuple_type):
open_cells = typedlist.List.empty_list(tuple_type) # Priority queue
pits = typedlist.List.empty_list(tuple_type) # FIFO queue
closed_cells = dem_mask.copy()
isertn = count()

# Push the edges onto priority queue
y, x = dem.shape

edge = _left(dem_mask)[:-1]
for row, col in zip(count(), edge):
if col >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = _bottom(dem_mask)[:-1]
for row, col in zip(edge, count()):
if row >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_right(dem_mask))[:-1]
for row, col in zip(count(y - 1, step=-1), edge):
if col >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_top(dem_mask))[:-1]
for row, col in zip(edge, count(x - 1, step=-1)):
if row >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
heapify(open_cells)

row_offsets = np.array([-1, -1, 0, 1, 1, 1, 0, -1])
col_offsets = np.array([0, 1, 1, 1, 0, -1, -1, -1])

pits_pos = 0
while open_cells or pits_pos < len(pits):
if pits_pos < len(pits):
elv, _, i, j = pits[pits_pos]
pits_pos += 1
else:
elv, _, i, j = heappop(open_cells)

for n in range(8):
row = i + row_offsets[n]
col = j + col_offsets[n]

if row < 0 or row >= y or col < 0 or col >= x:
continue

if dem_mask[row, col] or closed_cells[row, col]:
continue

if dem[row, col] <= elv:
dem[row, col] = elv
pits.append((elv, next(isertn), row, col))
else:
heappush(open_cells, (dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True

# pits book-keeping
if pits_pos == len(pits) and len(pits) > 1024:
# Queue is empty, lets clear it out
pits.clear()
pits_pos = 0

return dem
28 changes: 9 additions & 19 deletions pysheds/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import pandas as pd
import geojson
from affine import Affine
from numba.types import Tuple, int64
from numba import from_dtype

try:
import skimage.measure
import skimage.morphology
_HAS_SKIMAGE = True
except ModuleNotFoundError:
_HAS_SKIMAGE = False
Expand Down Expand Up @@ -2113,8 +2115,6 @@ def detect_depressions(self, dem, **kwargs):
depressions : Raster
Boolean Raster indicating locations of depressions.
"""
if not _HAS_SKIMAGE:
raise ImportError('detect_depressions requires skimage.morphology module')
input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata}
kwargs.update(input_overrides)
dem = self._input_handler(dem, **kwargs)
Expand Down Expand Up @@ -2148,23 +2148,13 @@ def fill_depressions(self, dem, nodata_out=np.nan, **kwargs):
Raster representing digital elevation data with multi-celled
depressions removed.
"""
if not _HAS_SKIMAGE:
raise ImportError('resolve_flats requires skimage.morphology module')
input_overrides = {'dtype' : np.float64, 'nodata' : dem.nodata}
kwargs.update(input_overrides)
dem = self._input_handler(dem, **kwargs)
dem_mask = self._get_nodata_cells(dem)
dem_mask[0, :] = True
dem_mask[-1, :] = True
dem_mask[:, 0] = True
dem_mask[:, -1] = True
# Make sure nothing flows to the nodata cells
seed = np.copy(dem)
seed[~dem_mask] = np.nanmax(dem)
dem_out = skimage.morphology.reconstruction(seed, dem, method='erosion')
dem_out = self._output_handler(data=dem_out, viewfinder=dem.viewfinder,
metadata=dem.metadata, nodata=nodata_out)
return dem_out
result = _self._priority_flood(dem, dem_mask)
dem_filled = self._output_handler(data=result,
viewfinder=dem.viewfinder,
metadata=dem.metadata,
nodata=dem.nodata)
return dem_filled

def detect_flats(self, dem, **kwargs):
"""
Expand Down
Loading