diff --git a/examples/world_example.rs b/examples/world_example.rs index 5be7ffa..4e970ab 100644 --- a/examples/world_example.rs +++ b/examples/world_example.rs @@ -25,7 +25,7 @@ use ordered_float::OrderedFloat; use plotly::common::{Fill, Line as PlotlyLine, Mode}; use plotly::{Layout, Plot, Scatter}; use rand::Rng; -use rustplanning::planning::rrt::{rrt, rrtstar}; +use rustplanning::planning::rrt::rrt; use rustplanning::tree::{Distance, HashTree}; use std::env; @@ -242,32 +242,23 @@ pub fn main() { let is_valid_fn = |start: &RobotPose, end: &RobotPose| world.is_valid(start, end, buffer); let success_fn = |pose: &RobotPose| pose.distance(&end) <= valid_distance; - let result; - let alg; - if use_rrtstar { + let alg = if use_rrtstar { println!("Finding path with RRT*"); - alg = "RRT*"; - result = rrtstar( - &start, - sample_fn, - extend_fn, - is_valid_fn, - success_fn, - rewire_radius, - 100000, - ); + "RRT*" } else { println!("Finding path with RRT"); - alg = "RRT"; - result = rrt( - &start, - sample_fn, - extend_fn, - is_valid_fn, - success_fn, - 100000, - ); - } + "RRT" + }; + let result = rrt( + &start, + sample_fn, + extend_fn, + is_valid_fn, + success_fn, + use_rrtstar, + rewire_radius, + 100000, + ); match result { Ok((path, tree)) => { println!("Path found!"); diff --git a/src/planning/rrt.rs b/src/planning/rrt.rs index 49e1075..e59d7ff 100644 --- a/src/planning/rrt.rs +++ b/src/planning/rrt.rs @@ -52,9 +52,35 @@ where Some((new_point, nearest.clone())) } -/// Basic RRT implementation. +fn rewire_tree ( + tree: &mut HashTree, + is_valid: &mut FV, + point: &T, + rewire_radius: f64, +) +where + T: Eq + Copy + Hash + Distance, + FV: FnMut(&T, &T) -> bool, +{ + // Get a list of all nodes that are within the sample radius, and rewire if necessary + let neighbors = tree.nearest_neighbors(point, rewire_radius); + let new_cost = tree.cost(point).unwrap(); + for (neighbor, distance) in neighbors.iter() { + if neighbor == point { + continue; + } + // If it's cheaper and valid to get to the neighbor from the new node reparent it + if distance + new_cost < tree.cost(neighbor).unwrap() { + if is_valid(point, neighbor) { + let _ = tree.set_parent(neighbor, point); + } + } + } +} + +/// Implementation of RRT planning algorithms. /// -/// Will attempt to compute a path using the RRT algorithm given the specified start pose +/// Will attempt to compute a path using the specified version of RRT given the start pose /// and user-defined coverage functions. /// /// # Parameters @@ -64,6 +90,8 @@ where /// - `extend`: Given two nodes, function to return an intermediate value between them /// - `is_valid`: Function to determine whether or not a link can be added between two nodes /// - `success`: Returns whether or not a node has reached the goal +/// - `use_rrtstar`: Whether or not to use RRT* +/// - `rewire_radius`: If using RRT*, the max distance to identify and rewire neighbors of newly added nodes /// - `max_iterations`: Maximum number of random samples to attempt before the search fails /// /// # Returns @@ -84,6 +112,8 @@ pub fn rrt( mut extend: FE, mut is_valid: FV, mut success: FD, + use_rrtstar: bool, + rewire_radius: f64, max_iterations: u64, ) -> Result<(Vec, HashTree), String> where @@ -96,6 +126,7 @@ where let mut tree = HashTree::new(start.clone()); for _ in 0..max_iterations { + // Sample the grab the nearest point, and extend in that direction let (new_point, nearest) = match extend_tree(&tree, &mut sample, &mut extend, &mut is_valid) { Some((new_point, nearest)) => (new_point, nearest), @@ -107,73 +138,8 @@ where continue; } - // Are we there yet? If so return the path. - if success(&new_point) { - match tree.path(&new_point) { - Ok(path) => return Ok((path, tree)), - Err(e) => return Err(e), - } - } - } - - // Otherwise we've hit max_iter with finding success - Err("Failed to find a path".to_string()) -} - -/// Basic implementation for RRTStar. -/// -/// Method signature is nearly identical to [`rrt`], though includes a radius for -/// rewiring neighbors of sampled nodes. -/// -/// # Parameters -/// -/// Are the same as `rrt` excepting for, -/// -/// - `rewire_radius`: The max distance to identify and rewire neighbors of newly added nodes -/// -pub fn rrtstar( - start: &T, - mut sample: FS, - mut extend: FE, - mut is_valid: FV, - mut success: FD, - rewire_radius: f64, - max_iterations: u64, -) -> Result<(Vec, HashTree), String> -where - T: Eq + Copy + Hash + Distance, - FS: FnMut() -> T, - FE: FnMut(&T, &T) -> T, - FV: FnMut(&T, &T) -> bool, - FD: FnMut(&T) -> bool, -{ - let mut tree = HashTree::new(start.clone()); - - for _ in 0..max_iterations { - // Sample the grab the nearest point, and extend in that direction - let (new_point, nearest) = match extend_tree(&tree, &mut sample, &mut extend, &mut is_valid) - { - Some((new_point, nearest)) => (new_point, nearest), - None => continue, - }; - - // If it's valid add it to the tree - tree.add_child(&nearest, new_point).unwrap(); - let new_cost = tree.cost(&new_point).unwrap(); - - // Get a list of all nodes that are within the sample radius, and rewire if necessary - let neighbors = tree.nearest_neighbors(&new_point, rewire_radius); - for (neighbor, distance) in neighbors.iter() { - if neighbor == &new_point { - continue; - } - if !is_valid(&new_point, neighbor) { - continue; - } - // If it's cheaper to get to the neighbor from the new node reparent it - if distance + new_cost < tree.cost(neighbor).unwrap() { - tree.set_parent(neighbor, &new_point)?; - } + if use_rrtstar { + rewire_tree(&mut tree, &mut is_valid, &new_point, rewire_radius); } // Are we there yet? If so return the path. diff --git a/src/tree.rs b/src/tree.rs index ebae6da..74d38b0 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -247,7 +247,7 @@ impl HashTree { /// Finds all nodes that are within the specified radius and returns a map of /// all closest elements and their values. - pub fn nearest_neighbors(&mut self, val: &T, radius: f64) -> HashMap { + pub fn nearest_neighbors(&self, val: &T, radius: f64) -> HashMap { // First iterate over all nodes to identify all neighbors let mut neighbors = HashMap::new(); for (_i, check) in self.nodes.iter().enumerate() { diff --git a/tests/rrt_test.rs b/tests/rrt_test.rs index 7b27c9d..2bd3771 100644 --- a/tests/rrt_test.rs +++ b/tests/rrt_test.rs @@ -22,7 +22,7 @@ use ordered_float::OrderedFloat; use rand::Rng; -use rustplanning::planning::rrt::{rrt, rrtstar}; +use rustplanning::planning::rrt::rrt; use rustplanning::tree::Distance; use std::fmt; @@ -92,13 +92,16 @@ fn run_rrt_test(use_rrtstar: bool) { let is_valid = |_: &Point2D, end: &Point2D| end.distance(&obstacle) > 1.0; let success = |p: &Point2D| p.distance(&goal) < success_distance; - let result; - if use_rrtstar { - result = rrtstar(&start, sample, extend, is_valid, success, 0.2, 10000); - } - else { - result = rrt(&start, sample, extend, is_valid, success, 10000); - } + let result = rrt( + &start, + sample, + extend, + is_valid, + success, + use_rrtstar, + 0.2, + 10000, + ); assert!(result.is_ok(), "Expected Ok result, got Err"); @@ -112,7 +115,6 @@ fn run_rrt_test(use_rrtstar: bool) { end.distance(&goal) < success_distance, "Path should end near the goal" ); - } #[test]