From e22ee660ce5acb37ed013742ff4ff293a5638dd2 Mon Sep 17 00:00:00 2001 From: "jason.regina" Date: Thu, 16 Jan 2025 17:09:38 +0000 Subject: [PATCH] add jit compiling and make lower level filter method --- .../hydrotools/events/baseflow/eckhardt.py | 20 ++++++++++--------- python/events/tests/test_baseflow.py | 14 +++++-------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/python/events/src/hydrotools/events/baseflow/eckhardt.py b/python/events/src/hydrotools/events/baseflow/eckhardt.py index 035e5126..99c82e8a 100644 --- a/python/events/src/hydrotools/events/baseflow/eckhardt.py +++ b/python/events/src/hydrotools/events/baseflow/eckhardt.py @@ -24,6 +24,7 @@ import numpy as np import numpy.typing as npt +from numba import jit, float64 def linear_recession_analysis( series: npt.ArrayLike, @@ -67,8 +68,9 @@ def maximum_baseflow_analysis( """ return 0.5 +@jit(float64[:](float64[:], float64, float64), nogil=True) def separate_baseflow( - series: npt.ArrayLike, + series: npt.NDArray, recession_constant: float, maximum_baseflow_index: float ) -> npt.NDArray: @@ -77,8 +79,8 @@ def separate_baseflow( Parameters ---------- - series: array-like, required - An array of streamflow values. Assumes first value in series is baseflow. + series: array-type, required + A numpy array of streamflow values. Assumes first value in series is baseflow. recession_constant: float, required Linear reservoir recession constant, a, from Eckhardt (2005, 2008). maximum_baseflow_index: float @@ -96,10 +98,10 @@ def separate_baseflow( # Instantiate baseflow series # Assume first value is baseflow - streamflow = np.asarray(series) - baseflow = np.empty(len(series)) - baseflow[0] = streamflow[0] - for i in range(1, len(series)): - baseflow[i] = A * baseflow[i-1] + B * streamflow[i] + baseflow = np.empty(series.size) + baseflow[0] = series[0] - return np.minimum(baseflow, streamflow) + # Apply filter and return result + for i in range(1, len(series)): + baseflow[i] = min(series[i], A * baseflow[i-1] + B * series[i]) + return baseflow diff --git a/python/events/tests/test_baseflow.py b/python/events/tests/test_baseflow.py index fd4842a1..9765582c 100644 --- a/python/events/tests/test_baseflow.py +++ b/python/events/tests/test_baseflow.py @@ -18,16 +18,12 @@ def test_maximum_baseflow_analysis(): def test_separate_baseflow(): rng = np.random.default_rng() - s = rng.normal(100.0, 10.0, 1000) + s = rng.normal(100.0, 10.0, 100) # Test numpy + from time import perf_counter + start = perf_counter() b = bf.separate_baseflow(s, 0.9, 0.5) - assert b[0] == s[0] - - # Test list - b = bf.separate_baseflow(s.tolist(), 0.9, 0.5) - assert b[0] == s[0] - - # Test pandas - b = bf.separate_baseflow(pd.Series(s), 0.9, 0.5) + end = perf_counter() + print(f"{end-start:.6f} s") assert b[0] == s[0]