Skip to content

Commit

Permalink
Merge pull request #879 from snoyer/filter_by-property
Browse files Browse the repository at this point in the history
allow to filter and group by property
  • Loading branch information
gumyr authored Jan 29, 2025
2 parents 45dc04c + bdd11a9 commit 9268f31
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
19 changes: 16 additions & 3 deletions src/build123d/topology/shape_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2351,7 +2351,7 @@ def faces(self) -> ShapeList[Face]:

def filter_by(
self,
filter_by: ShapePredicate | Axis | Plane | GeomType,
filter_by: ShapePredicate | Axis | Plane | GeomType | property,
reverse: bool = False,
tolerance: float = 1e-5,
) -> ShapeList[T]:
Expand Down Expand Up @@ -2446,6 +2446,11 @@ def pred(shape: Shape):
# convert input to callable predicate
if callable(filter_by):
predicate = filter_by
elif isinstance(filter_by, property):

def predicate(obj):
return filter_by.__get__(obj)

elif isinstance(filter_by, Axis):
predicate = axis_parallel_predicate(filter_by, tolerance=tolerance)
elif isinstance(filter_by, Plane):
Expand Down Expand Up @@ -2524,7 +2529,9 @@ def filter_by_position(

def group_by(
self,
group_by: Callable[[Shape], K] | Axis | Edge | Wire | SortBy = Axis.Z,
group_by: (
Callable[[Shape], K] | Axis | Edge | Wire | SortBy | property
) = Axis.Z,
reverse=False,
tol_digits=6,
) -> GroupBy[T, K]:
Expand Down Expand Up @@ -2594,6 +2601,9 @@ def key_f(obj):
elif callable(group_by):
key_f = group_by

elif isinstance(group_by, property):
key_f = group_by.__get__

else:
raise ValueError(f"Unsupported group_by function: {group_by}")

Expand Down Expand Up @@ -2625,7 +2635,7 @@ def solids(self) -> ShapeList[Solid]:

def sort_by(
self,
sort_by: Axis | Callable[[T], K] | Edge | Wire | SortBy = Axis.Z,
sort_by: Axis | Callable[[T], K] | Edge | Wire | SortBy | property = Axis.Z,
reverse: bool = False,
) -> ShapeList[T]:
"""sort by
Expand All @@ -2651,6 +2661,9 @@ def sort_by(
# If a callable is provided, use it directly as the key
objects = sorted(self, key=sort_by, reverse=reverse)

elif isinstance(sort_by, property):
objects = sorted(self, key=sort_by.__get__, reverse=reverse)

elif isinstance(sort_by, Axis):
if sort_by.wrapped is None:
raise ValueError("Cannot sort by an empty axis")
Expand Down
25 changes: 25 additions & 0 deletions tests/test_direct_api/test_shape_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ def test_sort_by_lambda(self):
self.assertAlmostEqual(smallest.area, math.pi * 1**2, 5)
self.assertAlmostEqual(largest.area, math.pi * 2**2, 5)

def test_sort_by_property(self):
box1 = Box(1, 1, 1)
box2 = Box(2, 2, 2)
box3 = Box(3, 3, 3)
unsorted_boxes = ShapeList([box2, box3, box1])
assert unsorted_boxes.sort_by(Solid.volume) == [box1, box2, box3]
assert unsorted_boxes.sort_by(Solid.volume, reverse=True) == [box3, box2, box1]

def test_sort_by_invalid(self):
with self.assertRaises(ValueError):
Solid.make_box(1, 1, 1).faces().sort_by(">Z")
Expand Down Expand Up @@ -119,6 +127,12 @@ def test_filter_by_callable_predicate(self):
self.assertEqual(len(shapelist.filter_by(lambda s: s.label == "A")), 2)
self.assertEqual(len(shapelist.filter_by(lambda s: s.label == "B")), 1)

def test_filter_by_property(self):
box1 = Box(2, 2, 2)
box2 = Box(2, 2, 2).translate((1, 1, 1))
assert len((box1 + box2).edges().filter_by(Edge.is_interior)) == 6
assert len((box1 - box2).edges().filter_by(Edge.is_interior)) == 3

def test_first_last(self):
vertices = (
Solid.make_box(1, 1, 1).vertices().sort_by(Axis((0, 0, 0), (1, 1, 1)))
Expand Down Expand Up @@ -187,6 +201,17 @@ def test_group_by_callable_predicate(self):

self.assertEqual([len(group) for group in result], [1, 3, 2])

def test_group_by_property(self):
box1 = Box(2, 2, 2)
box2 = Box(2, 2, 2).translate((1, 1, 1))
g1 = (box1 + box2).edges().group_by(Edge.is_interior)
assert len(g1.group(True)) == 6
assert len(g1.group(False)) == 24

g2 = (box1 - box2).edges().group_by(Edge.is_interior)
assert len(g2.group(True)) == 3
assert len(g2.group(False)) == 18

def test_group_by_retrieve_groups(self):
boxesA = [Solid.make_box(1, 1, 1) for _ in range(3)]
boxesB = [Solid.make_box(1, 1, 1) for _ in range(2)]
Expand Down

0 comments on commit 9268f31

Please sign in to comment.