From ea1adda7e832b7e78e888a11c035179ef887dbdc Mon Sep 17 00:00:00 2001 From: Nicholas Reinicke Date: Mon, 17 Jul 2023 14:09:14 -0600 Subject: [PATCH] return any added road attributes with shortest path --- mappymatch/maps/igraph/igraph_map.py | 9 ++++++++- mappymatch/maps/nx/nx_map.py | 12 ++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mappymatch/maps/igraph/igraph_map.py b/mappymatch/maps/igraph/igraph_map.py index 18302ac..6fd2496 100644 --- a/mappymatch/maps/igraph/igraph_map.py +++ b/mappymatch/maps/igraph/igraph_map.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import igraph as ig import networkx as nx @@ -102,6 +102,9 @@ def __init__(self, graph: ig.Graph): self._node_id_name = node_id_name self._edge_id_name = edge_id_name + # store the names of any additional added attributes + self._additional_attribute_names: Set[str] = set() + self._build_rtree() # build mapping from mappymatch road id to igraph edge id @@ -139,6 +142,9 @@ def _build_road( metadata[self._dist_weight] = edge_data.get(self._dist_weight) metadata[self._time_weight] = edge_data.get(self._time_weight) + for attr in self._additional_attribute_names: + metadata[attr] = edge_data.get(attr) + road = Road( RoadId(source_node_id, target_node_id, road_key), edge_data[self._geom_key], @@ -208,6 +214,7 @@ def set_road_attributes(self, attributes: Dict[RoadId, Dict[str, Any]]): if edge_id is None: raise ValueError(f"Road id {road_id} not found in graph") for attr, val in attrs.items(): + self._additional_attribute_names.add(attr) if attr == self._geom_key: geom_updated = True self.g.es[edge_id][attr] = val diff --git a/mappymatch/maps/nx/nx_map.py b/mappymatch/maps/nx/nx_map.py index e23526a..1d8618b 100644 --- a/mappymatch/maps/nx/nx_map.py +++ b/mappymatch/maps/nx/nx_map.py @@ -2,7 +2,7 @@ import json from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import networkx as nx import numpy as np @@ -72,6 +72,8 @@ def __init__(self, graph: nx.MultiDiGraph): self._metadata_key = metadata_key self._crs_key = crs_key + self._addtional_attribute_names: Set[str] = set() + self._build_rtree() def _has_road_id(self, road_id: RoadId) -> bool: @@ -98,6 +100,9 @@ def _build_road( metadata[self._dist_weight] = edge_data.get(self._dist_weight) metadata[self._time_weight] = edge_data.get(self._time_weight) + for attr_name in self._addtional_attribute_names: + metadata[attr_name] = edge_data.get(attr_name) + road = Road( road_id, edge_data[self._geom_key], @@ -107,7 +112,6 @@ def _build_road( return road def _build_rtree(self): - idx = rt.index.Index() for i, gtuple in enumerate(self.g.edges(data=True, keys=True)): u, v, k, d = gtuple @@ -167,6 +171,10 @@ def set_road_attributes(self, attributes: Dict[RoadId, Dict[str, Any]]): Returns: None """ + for attrs in attributes.values(): + for attr_name in attrs.keys(): + self._addtional_attribute_names.add(attr_name) + nx.set_edge_attributes(self.g, attributes) self._build_rtree()