diff --git a/gerrychain/partition/partition.py b/gerrychain/partition/partition.py index b9204c1f..14a0cf2f 100644 --- a/gerrychain/partition/partition.py +++ b/gerrychain/partition/partition.py @@ -31,9 +31,11 @@ class Partition: '_cache' ) + default_updaters = {"cut_edges": cut_edges} + def __init__( self, graph=None, assignment=None, updaters=None, parent=None, flips=None, - use_cut_edges=True + use_default_updaters=True ): """ :param graph: Underlying graph. @@ -41,18 +43,17 @@ def __init__( :param updaters: Dictionary of functions to track data about the partition. The keys are stored as attributes on the partition class, which the functions compute. - :param use_cut_edges: If `False`, do not include `cut_edges` updater by default - and do not calculate edge flows. + :param use_default_updaters: If `False`, do not include default updaters. """ if parent is None: - self._first_time(graph, assignment, updaters, use_cut_edges) + self._first_time(graph, assignment, updaters, use_default_updaters) else: self._from_parent(parent, flips) self._cache = dict() self.subgraphs = SubgraphView(self.graph, self.parts) - def _first_time(self, graph, assignment, updaters, use_cut_edges): + def _first_time(self, graph, assignment, updaters, use_default_updaters): if isinstance(graph, Graph): self.graph = FrozenGraph(graph) elif isinstance(graph, networkx.Graph): @@ -71,8 +72,8 @@ def _first_time(self, graph, assignment, updaters, use_cut_edges): if updaters is None: updaters = dict() - if use_cut_edges: - self.updaters = {"cut_edges": cut_edges} + if use_default_updaters: + self.updaters = self.default_updaters else: self.updaters = {} diff --git a/tests/partition/test_partition.py b/tests/partition/test_partition.py index 1a20d6f7..17f57bd6 100644 --- a/tests/partition/test_partition.py +++ b/tests/partition/test_partition.py @@ -153,3 +153,25 @@ def test_partition_has_default_updaters(example_partition): def test_partition_has_keys(example_partition): assert "cut_edges" in set(example_partition.keys()) + + +def test_geographic_partition_has_keys(example_geographic_partition): + keys = set(example_geographic_partition.updaters.keys()) + + assert "perimeter" in keys + assert "exterior_boundaries" in keys + assert "interior_boundaries" in keys + assert "boundary_nodes" in keys + assert "cut_edges" in keys + assert "area" in keys + assert "cut_edges_by_part" in keys + + +def test_partition_has_default_updaters(example_geographic_partition): + assert hasattr(example_geographic_partition, "perimeter") + assert hasattr(example_geographic_partition, "exterior_boundaries") + assert hasattr(example_geographic_partition, "interior_boundaries") + assert hasattr(example_geographic_partition, "boundary_nodes") + assert hasattr(example_geographic_partition, "cut_edges") + assert hasattr(example_geographic_partition, "area") + assert hasattr(example_geographic_partition, "cut_edges_by_part")