diff --git a/src/planning/rrt.rs b/src/planning/rrt.rs index b5564af..7515e8b 100644 --- a/src/planning/rrt.rs +++ b/src/planning/rrt.rs @@ -24,6 +24,34 @@ use crate::tree::Distance; use crate::tree::Tree; use std::hash::Hash; +/// Attempts to randomly extend the tree in an arbitrary direction. +/// Return the new point and the nearest neighbor, if available. +/// Otherwise return None. +fn extend_tree( + tree: &Tree, + sample: &mut FS, + extend: &mut FE, + is_valid: &mut FV, +) -> Option<(T, T)> +where + T: Eq + Copy + Hash + Distance, + FS: FnMut() -> T, + FE: FnMut(&T, &T) -> T, + FV: FnMut(&T) -> bool, +{ + // Sample the grab the nearest point, and extend in that direction + let s = sample(); + let nearest = tree.nearest_neighbor(&s); + let new_point = extend(&nearest, &s); + + // If it is an invalid point try again + if !is_valid(&new_point) { + return None; + } + + Some((new_point, nearest.clone())) +} + /// Basic RRT implementation. /// /// Will attempt to compute a path using the RRT algorithm given the specified start pose @@ -38,7 +66,7 @@ use std::hash::Hash; /// - `success`: Determines whether or not we have reached the goal /// - `max_iterations`: Maximum number of random samples to attempt before the search fails /// -/// /// # Returns +/// # Returns /// Returns a `Result` containing either: /// - `Ok(Vec)`: A vector of points of type `T` representing the path from the start to a point /// satisfying the `success` condition, if such a path is found within the given number @@ -68,18 +96,14 @@ where let mut tree = Tree::new(start.clone()); for _ in 0..max_iterations { - // Sample the grab the nearest point, and extend in that direction - let s = sample(); - let nearest = tree.nearest(&s); - let new_point = extend(&nearest, &s); - - // If it is an invalid point try again - if !is_valid(&new_point) { - continue; - } + let (new_point, nearest) = match extend_tree(&tree, &mut sample, &mut extend, &mut is_valid) + { + Some((new_point, nearest)) => (new_point, nearest), + None => continue, + }; // Otherwise it's valid so add it to the tree - if let Ok(_) = tree.add_child(nearest, new_point) { + if let Ok(_) = tree.add_child(&nearest, new_point) { // We're good } else { // Then the child wasn't added for some reason so just try again diff --git a/src/tree.rs b/src/tree.rs index bcffc45..d42eaa6 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -29,29 +29,26 @@ use std::hash::Hash; /// /// Must be used with [Tree] since children are referenced by index in the [Tree]'s node vector. #[derive(Debug)] -struct TreeNode { +struct Node { + // The value of this node. value: T, - // Maintains a list of pointers to the children's location in the parent's vector - children: Vec, - // Location of the nodes parent, if available parent: Option, + + // Maintains a list of pointers to the children's location in the parent's vector. + // Using a vector to maintain order for tree traversals. + children: Vec, } -impl TreeNode { - fn new(val: T, parent: Option) -> Self { - TreeNode { - value: val, +impl Node { + fn new(value: T, parent: Option) -> Self { + Node { + value: value, parent: parent, children: Vec::new(), } } - - // Add a child to this node's children. Uniqueness must be enforced by the caller. - fn add_child(&mut self, child: usize) { - self.children.push(child); - } } /// Define a distance trait for tree node values. @@ -62,7 +59,7 @@ pub trait Distance { /// DFS Iterator for a [Tree] pub struct DepthFirstIterator<'a, T> where - T: 'a + Eq + Copy + Distance + Hash, + T: 'a + Eq + Clone + Distance + Hash, { tree: &'a Tree, stack: Vec, @@ -70,7 +67,7 @@ where impl<'a, T> DepthFirstIterator<'a, T> where - T: Eq + Copy + Distance + Hash, + T: Eq + Clone + Distance + Hash, { fn new(tree: &'a Tree) -> Self { let mut stack = Vec::new(); @@ -84,7 +81,7 @@ where impl<'a, T> Iterator for DepthFirstIterator<'a, T> where - T: Eq + Copy + Distance + Hash, + T: Eq + Clone + Distance + Hash, { type Item = &'a T; @@ -102,7 +99,8 @@ where /// Basic tree for use in search algorithms. /// -/// Provides functions for creating, growing, and finding the nearest neighbor to `T`. +/// Provides functions for creating, growing, finding the nearest neighbors to `T`, +/// and rewiring the based on cost are provided. /// Node values must be unique. /// /// TODO: Make this a KD Tree? @@ -110,30 +108,31 @@ where #[derive(Debug)] pub struct Tree where - T: Eq + Copy + Distance + Hash, + T: Eq + Clone + Distance + Hash, { - // Store nodes in a vector to support easy iteration and growth of the tree. - // The root will always be at idx 0. - nodes: Vec>, + // Detailed node data for the tree. + nodes: Vec>, - // Support constant time lookup of nodes by value + // Support constant time lookup of nodes data with a value - node index map. nodes_map: HashMap, } -impl Tree { - +impl Tree { /// Construct a new tree with the specified value as the root node. + /// + /// The node will take ownership of the provided value. pub fn new(val: T) -> Self { - let root_node = TreeNode::new(val, None); let mut nodes = Vec::new(); let mut nodes_map = HashMap::new(); + // Construct root node and add it to storage + let root_node = Node::new(val.clone(), None); nodes.push(root_node); nodes_map.insert(val, 0); Tree { - nodes: nodes, - nodes_map: nodes_map, + nodes, + nodes_map, } } @@ -143,7 +142,7 @@ impl Tree { /// /// If the parent is not found in the tree. /// If the child is already in the tree. - pub fn add_child(&mut self, parent: T, child: T) -> Result<(), String> { + pub fn add_child(&mut self, parent: &T, child: T) -> Result<(), String> { // Cannot duplicate children if self.nodes_map.contains_key(&child) { return Err("The child is already in the tree".to_string()); @@ -152,9 +151,9 @@ impl Tree { if let Some(&parent_idx) = self.nodes_map.get(&parent) { // Append the child node to the nodes vector and note the location in the map. let child_idx = self.nodes.len(); - self.nodes.push(TreeNode::new(child, Some(parent_idx))); + self.nodes.push(Node::new(child.clone(), Some(parent_idx))); self.nodes_map.insert(child, child_idx); - self.nodes[parent_idx].add_child(child_idx); + self.nodes[parent_idx].children.push(child_idx); } else { return Err("The parent cannot be found in the tree".to_string()); } @@ -168,8 +167,9 @@ impl Tree { } /// Returns the closest element to the specified value - pub fn nearest(&self, val: &T) -> T { - self.nodes + pub fn nearest_neighbor(&self, val: &T) -> &T { + &self + .nodes .iter() .min_by(|a, b| { let da = val.distance(&a.value); @@ -197,12 +197,12 @@ impl Tree { } // Build the path from end to beginning - let mut path: Vec = Vec::new(); + let mut path = Vec::new(); // Loop until you get to the root let mut cur_idx = Some(self.nodes_map[&end]); while let Some(idx) = cur_idx { - path.push(self.nodes[idx].value); + path.push(self.nodes[idx].value.clone()); cur_idx = self.nodes[idx].parent; } @@ -216,7 +216,7 @@ impl Tree { // Unit tests // -// Needed for distancing points on a plane +// Needed for distancing points on a line impl Distance for i32 { fn distance(&self, other: &Self) -> f64 { (self - other).abs().into() @@ -235,19 +235,19 @@ mod tests { assert_eq!(tree.nodes[0].value, 1); // Add a child and make sure everything is ok - assert!(tree.add_child(1, 2).is_ok()); + assert!(tree.add_child(&1, 2).is_ok()); assert_eq!(tree.size(), 2); // Make the tree bigger - assert!(tree.add_child(1, 3).is_ok()); - assert!(tree.add_child(2, 4).is_ok()); + assert!(tree.add_child(&1, 3).is_ok()); + assert!(tree.add_child(&2, 4).is_ok()); assert_eq!(tree.size(), 4); // Add an existing child and everything is not ok - assert!(tree.add_child(1, 2).is_err()); + assert!(tree.add_child(&1, 2).is_err()); // Add to a nonexistent parent and everything is not ok - assert!(tree.add_child(3, 2).is_err()); + assert!(tree.add_child(&3, 2).is_err()); } #[test] @@ -255,16 +255,16 @@ mod tests { // Construct tree with many nodes let mut tree: Tree = Tree::new(1); - assert!(tree.add_child(1, 2).is_ok()); - assert!(tree.add_child(1, 3).is_ok()); - assert!(tree.add_child(2, 4).is_ok()); - assert!(tree.add_child(2, 5).is_ok()); - assert!(tree.add_child(2, 6).is_ok()); + assert!(tree.add_child(&1, 2).is_ok()); + assert!(tree.add_child(&1, 3).is_ok()); + assert!(tree.add_child(&2, 4).is_ok()); + assert!(tree.add_child(&2, 5).is_ok()); + assert!(tree.add_child(&2, 6).is_ok()); // Make assertions - assert_eq!(tree.nearest(&7), 6); - assert_eq!(tree.nearest(&-1), 1); - assert_eq!(tree.nearest(&3), 3); + assert_eq!(tree.nearest_neighbor(&7), &6); + assert_eq!(tree.nearest_neighbor(&-1), &1); + assert_eq!(tree.nearest_neighbor(&3), &3); } #[test] @@ -272,11 +272,11 @@ mod tests { // Construct tree with many nodes let mut tree: Tree = Tree::new(1); - assert!(tree.add_child(1, 2).is_ok()); - assert!(tree.add_child(1, 3).is_ok()); - assert!(tree.add_child(2, 4).is_ok()); - assert!(tree.add_child(2, 5).is_ok()); - assert!(tree.add_child(3, 6).is_ok()); + assert!(tree.add_child(&1, 2).is_ok()); + assert!(tree.add_child(&1, 3).is_ok()); + assert!(tree.add_child(&2, 4).is_ok()); + assert!(tree.add_child(&2, 5).is_ok()); + assert!(tree.add_child(&3, 6).is_ok()); // Expected order let expected_dfs_order = vec![1, 2, 4, 5, 3, 6]; @@ -291,20 +291,20 @@ mod tests { // Construct tree with many nodes let mut tree: Tree = Tree::new(1); - assert!(tree.add_child(1, 2).is_ok()); - assert!(tree.add_child(1, 3).is_ok()); - assert!(tree.add_child(2, 4).is_ok()); - assert!(tree.add_child(2, 5).is_ok()); - assert!(tree.add_child(3, 7).is_ok()); - assert!(tree.add_child(5, 6).is_ok()); + assert!(tree.add_child(&1, 2).is_ok()); + assert!(tree.add_child(&1, 3).is_ok()); + assert!(tree.add_child(&2, 4).is_ok()); + assert!(tree.add_child(&2, 5).is_ok()); + assert!(tree.add_child(&3, 7).is_ok()); + assert!(tree.add_child(&5, 6).is_ok()); - // Verify + // Verify expected paths to different nodes let ep1 = vec![1, 2, 5, 6]; - let cp1: Vec = tree.path(&6).unwrap(); + let cp1 = tree.path(&6).unwrap(); assert_eq!(cp1, ep1); let ep2 = vec![1, 3, 7]; - let cp2: Vec = tree.path(&7).unwrap(); + let cp2 = tree.path(&7).unwrap(); assert_eq!(cp2, ep2); // Invalid node diff --git a/tests/rrt_test.rs b/tests/rrt_test.rs index 2408e7a..5ec535d 100644 --- a/tests/rrt_test.rs +++ b/tests/rrt_test.rs @@ -95,7 +95,7 @@ fn test_rrt() { let success = |p: &Point2D| p.distance(&goal) < success_distance; let result = rrt( - &start, sample, extend, is_valid, success, 10000, // Max iterations + &start, sample, extend, is_valid, success, 10000, ); assert!(result.is_ok(), "Expected Ok result, got Err");