Skip to content

Commit 4d3c136

Browse files
feat: average the value of different branch point paths
1 parent 8341f3e commit 4d3c136

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

kimimaro/utility.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Dict, Union, List, Tuple
22

3+
from collections import defaultdict
34
import copy
45

56
import numpy as np
@@ -168,6 +169,9 @@ def cross_sectional_area_helper(skel, binimg, roi):
168169
areas = np.zeros([all_verts.shape[0]], dtype=np.float32)
169170
contacts = np.zeros([all_verts.shape[0]], dtype=np.uint8)
170171

172+
branch_pts = set(skel.branches())
173+
branch_pt_vals = defaultdict(list)
174+
171175
paths = skel.paths()
172176

173177
normal = np.array([1,0,0], dtype=np.float32)
@@ -190,19 +194,25 @@ def cross_sectional_area_helper(skel, binimg, roi):
190194
normal = normals[i,:]
191195
normal /= np.linalg.norm(normal)
192196

193-
for i, vert in enumerate(path):
197+
for i, vert in tqdm(enumerate(path), total=path.shape[0]):
194198
if np.any(vert < 0) or np.any(vert > shape):
195199
continue
196200

197201
idx = mapping[tuple(vert)]
198202
normal = normals[i]
199203

200-
if areas[idx] == 0 or (repair_contacts and contacts[idx] > 0):
201-
areas[idx], contacts[idx] = xs3d.cross_sectional_area(
204+
if areas[idx] == 0 or idx in branch_pts or (repair_contacts and contacts[idx] > 0):
205+
areas[idx], contact = xs3d.cross_sectional_area(
202206
binimg, vert,
203207
normal, anisotropy,
204208
return_contact=True,
205209
)
210+
if repair_contacts:
211+
contacts[idx] = contact
212+
else:
213+
contacts[idx] |= contact # accumulate for branch points
214+
if idx in branch_pts:
215+
branch_pt_vals[idx].append(areas[idx])
206216
if visualize_section_planes:
207217
img = xs3d.cross_section(
208218
binimg, vert,
@@ -214,6 +224,9 @@ def cross_sectional_area_helper(skel, binimg, roi):
214224
import microviewer
215225
microviewer.view(cross_sections, seg=True)
216226

227+
for idx, vals in branch_pt_vals.items():
228+
areas[idx] = sum(vals) / len(vals)
229+
217230
add_property(skel, prop)
218231

219232
skel.cross_sectional_area = areas

0 commit comments

Comments
 (0)