diff --git a/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc b/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc index 44fda6cb..6a00bc72 100644 --- a/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc +++ b/src/pymatching/sparse_blossom/driver/mwpm_decoding.cc @@ -86,7 +86,9 @@ void process_timeline_until_completion(pm::Mwpm& mwpm, const std::vector= mwpm.flooder.graph.nodes.size()) throw std::invalid_argument("Detection event index too large"); - mwpm.create_detection_event(&mwpm.flooder.graph.nodes[detection]); + if (detection + 1 > mwpm.flooder.graph.is_user_graph_boundary_node.size() || + !mwpm.flooder.graph.is_user_graph_boundary_node[detection]) + mwpm.create_detection_event(&mwpm.flooder.graph.nodes[detection]); } } else { @@ -102,7 +104,9 @@ void process_timeline_until_completion(pm::Mwpm& mwpm, const std::vector mwpm.flooder.graph.is_user_graph_boundary_node.size() || + !mwpm.flooder.graph.is_user_graph_boundary_node[detection]) + mwpm.create_detection_event(&mwpm.flooder.graph.nodes[detection]); } else { // Unmark node mwpm.flooder.graph.nodes[detection].radius_of_arrival = 0; diff --git a/src/pymatching/sparse_blossom/driver/user_graph.cc b/src/pymatching/sparse_blossom/driver/user_graph.cc index 77e79040..564b3d4b 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.cc @@ -258,6 +258,12 @@ pm::MatchingGraph pm::UserGraph::to_matching_graph(pm::weight_int num_distinct_w matching_graph.add_boundary_edge(u, weight, observables); }); matching_graph.normalising_constant = normalising_constant; + if (boundary_nodes.size() > 0) { + matching_graph.is_user_graph_boundary_node.clear(); + matching_graph.is_user_graph_boundary_node.resize(nodes.size(), false); + for (auto& i : boundary_nodes) + matching_graph.is_user_graph_boundary_node[i] = true; + } return matching_graph; } diff --git a/src/pymatching/sparse_blossom/driver/user_graph.test.cc b/src/pymatching/sparse_blossom/driver/user_graph.test.cc index 32225347..1a2fee6b 100644 --- a/src/pymatching/sparse_blossom/driver/user_graph.test.cc +++ b/src/pymatching/sparse_blossom/driver/user_graph.test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "pymatching/sparse_blossom/driver/user_graph.h" +#include "pymatching/sparse_blossom/driver/mwpm_decoding.h" #include #include @@ -143,3 +144,25 @@ TEST(UserGraph, NodesAlongShortestPath) { ASSERT_EQ(nodes, nodes_expected); } } + +TEST(UserGraph, DecodeUserGraphDetectionEventOnBoundaryNode) { + { + pm::UserGraph graph; + graph.add_or_merge_edge(0, 1, {0}, 1.0, -1); + graph.add_or_merge_edge(1, 2, {1}, 1.0, -1); + graph.set_boundary({2}); + auto& mwpm = graph.get_mwpm(); + pm::ExtendedMatchingResult res(mwpm.flooder.graph.num_observables); + pm::decode_detection_events(mwpm, {2}, res.obs_crossed.data(), res.weight); + } + + { + pm::UserGraph graph; + graph.add_or_merge_edge(0, 1, {0}, -1.0, -1); + graph.add_or_merge_edge(1, 2, {1}, 1.0, -1); + graph.set_boundary({2}); + auto& mwpm = graph.get_mwpm(); + pm::ExtendedMatchingResult res(mwpm.flooder.graph.num_observables); + pm::decode_detection_events(mwpm, {2}, res.obs_crossed.data(), res.weight); + } +} diff --git a/src/pymatching/sparse_blossom/flooder/graph.cc b/src/pymatching/sparse_blossom/flooder/graph.cc index 55e78ba4..b941b222 100644 --- a/src/pymatching/sparse_blossom/flooder/graph.cc +++ b/src/pymatching/sparse_blossom/flooder/graph.cc @@ -104,6 +104,7 @@ MatchingGraph::MatchingGraph(MatchingGraph&& graph) noexcept : nodes(std::move(graph.nodes)), negative_weight_detection_events_set(std::move(graph.negative_weight_detection_events_set)), negative_weight_observables_set(std::move(graph.negative_weight_observables_set)), + is_user_graph_boundary_node(std::move(graph.is_user_graph_boundary_node)), negative_weight_sum(graph.negative_weight_sum), num_nodes(graph.num_nodes), num_observables(graph.num_observables), diff --git a/src/pymatching/sparse_blossom/flooder/graph.h b/src/pymatching/sparse_blossom/flooder/graph.h index 4db0b3b1..3cf4093b 100644 --- a/src/pymatching/sparse_blossom/flooder/graph.h +++ b/src/pymatching/sparse_blossom/flooder/graph.h @@ -38,6 +38,11 @@ class MatchingGraph { std::set negative_weight_observables_set; /// The sum of the negative edge weights. This number is negative, rather than the absolute value. pm::total_weight_int negative_weight_sum; + /// is_user_graph_boundary_node is only filled if MatchingGraph is constructed from a UserGraph + /// is_user_graph_boundary_node[i] is true if i is a boudary node in the UserGraph + /// This vector is used to check that a detection event is not added to a node marked as a boundary node + /// in the UserGraph (as such an event cannot be matched, and will raise an error). + std::vector is_user_graph_boundary_node; size_t num_nodes; size_t num_observables; /// This is the normalising constant that the edge weights were multiplied by when converting from floats to diff --git a/tests/matching/decode_test.py b/tests/matching/decode_test.py index 38ad2a66..2ae29cbc 100644 --- a/tests/matching/decode_test.py +++ b/tests/matching/decode_test.py @@ -205,3 +205,16 @@ def test_decode_wrong_syndrome_type_raises_type_error(): m = Matching() m.add_edge(0, 1) m.decode([0, "A"]) + + +def test_syndrome_on_boundary_nodes(): + m = Matching() + m.add_edge(0, 1, fault_ids={0}) + m.add_edge(1, 2, fault_ids={1}) + m.add_edge(2, 3, fault_ids={2}) + m.add_edge(3, 4, fault_ids={3}) + m.set_boundary_nodes({3, 4}) + m.decode([0, 0, 0, 1, 0]) + m.decode([0, 0, 0, 0, 1]) + m.decode([0, 0, 0, 1, 1]) + m.decode([1, 0, 1, 0, 1])