Skip to content

Commit ffac52b

Browse files
test: simple unit tests for moving average
1 parent 5b04d99 commit ffac52b

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

automated_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import kimimaro.intake
88
import kimimaro.skeletontricks
9+
from kimimaro.utility import moving_average
910

1011
def test_empty_image():
1112
labels = np.zeros( (256, 256, 256), dtype=bool)
@@ -469,7 +470,37 @@ def test_cross_sectional_area():
469470
assert np.all(skel.cross_sectional_area == 9)
470471

471472

473+
def test_moving_average():
472474

475+
data = np.array([])
476+
assert np.all(moving_average(data, 1) == data)
477+
assert np.all(moving_average(data, 2) == data)
478+
479+
data = np.array([1,1,1,1,1,1,1,1,1,1,1])
480+
assert np.all(moving_average(data, 1) == data)
481+
482+
data = np.array([1,1,1,1,1,1,1,1,1,1,1,1])
483+
assert np.all(moving_average(data, 1) == data)
484+
485+
data = np.array([1,1,1,1,1,10,1,1,1,1,1])
486+
assert np.all(moving_average(data, 1) == data)
487+
488+
data = np.array([1,1,1,1,1,1,1,1,1,1,1])
489+
assert np.all(moving_average(data, 2) == data)
490+
491+
data = np.array([0,1,1,1,1,1,1,1,1,1,0])
492+
ans = np.array([
493+
0,0.5,1,1,1,1,1,1,1,1,0.5
494+
])
495+
assert np.all(moving_average(data, 2) == ans)
496+
497+
data = np.array([0,1,1,1,1,1,1,1,1,1,0])
498+
ans = np.array([
499+
1/3,1/3,2/3,1,1,1,1,1,1,1,2/3
500+
])
501+
res = moving_average(data, 3)
502+
assert np.all(res == ans)
503+
assert len(ans) == len(data)
473504

474505

475506

kimimaro/utility.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,17 @@ def moving_average(a:np.ndarray, n:int) -> np.ndarray:
194194
raise ValueError(f"Window size ({n}), must be >= 1.")
195195
elif n == 1:
196196
return a
197-
mirror = (n - 1) / 2
198-
extra = 0
199-
if mirror != int(mirror):
200-
extra = 1
201-
mirror = int(mirror)
202-
a = np.pad(a, [[mirror, mirror+extra],[0,0]], mode="symmetric")
197+
198+
if len(a) == 0:
199+
return a
200+
201+
if a.ndim == 2:
202+
a = np.pad(a, [[n, n],[0,0]], mode="symmetric")
203+
else:
204+
a = np.pad(a, [n, n], mode="symmetric")
205+
203206
ret = np.cumsum(a, dtype=float, axis=0)
204-
ret[n:] = ret[n:] - ret[:-n]
205-
return ret[n - 1:] / n
207+
ret = (ret[n:] - ret[:-n])[:-n]
208+
ret /= float(n)
209+
return ret
206210

0 commit comments

Comments
 (0)