Skip to content

Commit

Permalink
Unify the RRT interface and remove duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
eholum committed Mar 21, 2024
1 parent 8888193 commit 382c088
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 103 deletions.
39 changes: 15 additions & 24 deletions examples/world_example.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use ordered_float::OrderedFloat;
use plotly::common::{Fill, Line as PlotlyLine, Mode};
use plotly::{Layout, Plot, Scatter};
use rand::Rng;
use rustplanning::planning::rrt::{rrt, rrtstar};
use rustplanning::planning::rrt::rrt;
use rustplanning::tree::{Distance, HashTree};
use std::env;

Expand Down Expand Up @@ -242,32 +242,23 @@ pub fn main() {
let is_valid_fn = |start: &RobotPose, end: &RobotPose| world.is_valid(start, end, buffer);
let success_fn = |pose: &RobotPose| pose.distance(&end) <= valid_distance;

let result;
let alg;
if use_rrtstar {
let alg = if use_rrtstar {
println!("Finding path with RRT*");
alg = "RRT*";
result = rrtstar(
&start,
sample_fn,
extend_fn,
is_valid_fn,
success_fn,
rewire_radius,
100000,
);
"RRT*"
} else {
println!("Finding path with RRT");
alg = "RRT";
result = rrt(
&start,
sample_fn,
extend_fn,
is_valid_fn,
success_fn,
100000,
);
}
"RRT"
};
let result = rrt(
&start,
sample_fn,
extend_fn,
is_valid_fn,
success_fn,
use_rrtstar,
rewire_radius,
100000,
);
match result {
Ok((path, tree)) => {
println!("Path found!");
Expand Down
104 changes: 35 additions & 69 deletions src/planning/rrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,35 @@ where
Some((new_point, nearest.clone()))
}

/// Basic RRT implementation.
fn rewire_tree<T, FV> (
tree: &mut HashTree<T>,
is_valid: &mut FV,
point: &T,
rewire_radius: f64,
)
where
T: Eq + Copy + Hash + Distance,
FV: FnMut(&T, &T) -> bool,
{
// Get a list of all nodes that are within the sample radius, and rewire if necessary
let neighbors = tree.nearest_neighbors(point, rewire_radius);
let new_cost = tree.cost(point).unwrap();
for (neighbor, distance) in neighbors.iter() {
if neighbor == point {
continue;
}
// If it's cheaper and valid to get to the neighbor from the new node reparent it
if distance + new_cost < tree.cost(neighbor).unwrap() {
if is_valid(point, neighbor) {
let _ = tree.set_parent(neighbor, point);
}
}
}
}

/// Implementation of RRT planning algorithms.
///
/// Will attempt to compute a path using the RRT algorithm given the specified start pose
/// Will attempt to compute a path using the specified version of RRT given the start pose
/// and user-defined coverage functions.
///
/// # Parameters
Expand All @@ -64,6 +90,8 @@ where
/// - `extend`: Given two nodes, function to return an intermediate value between them
/// - `is_valid`: Function to determine whether or not a link can be added between two nodes
/// - `success`: Returns whether or not a node has reached the goal
/// - `use_rrtstar`: Whether or not to use RRT*
/// - `rewire_radius`: If using RRT*, the max distance to identify and rewire neighbors of newly added nodes
/// - `max_iterations`: Maximum number of random samples to attempt before the search fails
///
/// # Returns
Expand All @@ -84,6 +112,8 @@ pub fn rrt<T, FS, FE, FV, FD>(
mut extend: FE,
mut is_valid: FV,
mut success: FD,
use_rrtstar: bool,
rewire_radius: f64,
max_iterations: u64,
) -> Result<(Vec<T>, HashTree<T>), String>
where
Expand All @@ -96,6 +126,7 @@ where
let mut tree = HashTree::new(start.clone());

for _ in 0..max_iterations {
// Sample the grab the nearest point, and extend in that direction
let (new_point, nearest) = match extend_tree(&tree, &mut sample, &mut extend, &mut is_valid)
{
Some((new_point, nearest)) => (new_point, nearest),
Expand All @@ -107,73 +138,8 @@ where
continue;
}

// Are we there yet? If so return the path.
if success(&new_point) {
match tree.path(&new_point) {
Ok(path) => return Ok((path, tree)),
Err(e) => return Err(e),
}
}
}

// Otherwise we've hit max_iter with finding success
Err("Failed to find a path".to_string())
}

/// Basic implementation for RRTStar.
///
/// Method signature is nearly identical to [`rrt`], though includes a radius for
/// rewiring neighbors of sampled nodes.
///
/// # Parameters
///
/// Are the same as `rrt` excepting for,
///
/// - `rewire_radius`: The max distance to identify and rewire neighbors of newly added nodes
///
pub fn rrtstar<T, FS, FE, FV, FD>(
start: &T,
mut sample: FS,
mut extend: FE,
mut is_valid: FV,
mut success: FD,
rewire_radius: f64,
max_iterations: u64,
) -> Result<(Vec<T>, HashTree<T>), String>
where
T: Eq + Copy + Hash + Distance,
FS: FnMut() -> T,
FE: FnMut(&T, &T) -> T,
FV: FnMut(&T, &T) -> bool,
FD: FnMut(&T) -> bool,
{
let mut tree = HashTree::new(start.clone());

for _ in 0..max_iterations {
// Sample the grab the nearest point, and extend in that direction
let (new_point, nearest) = match extend_tree(&tree, &mut sample, &mut extend, &mut is_valid)
{
Some((new_point, nearest)) => (new_point, nearest),
None => continue,
};

// If it's valid add it to the tree
tree.add_child(&nearest, new_point).unwrap();
let new_cost = tree.cost(&new_point).unwrap();

// Get a list of all nodes that are within the sample radius, and rewire if necessary
let neighbors = tree.nearest_neighbors(&new_point, rewire_radius);
for (neighbor, distance) in neighbors.iter() {
if neighbor == &new_point {
continue;
}
if !is_valid(&new_point, neighbor) {
continue;
}
// If it's cheaper to get to the neighbor from the new node reparent it
if distance + new_cost < tree.cost(neighbor).unwrap() {
tree.set_parent(neighbor, &new_point)?;
}
if use_rrtstar {
rewire_tree(&mut tree, &mut is_valid, &new_point, rewire_radius);
}

// Are we there yet? If so return the path.
Expand Down
2 changes: 1 addition & 1 deletion src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl<T: Eq + Clone + Distance + Hash> HashTree<T> {

/// Finds all nodes that are within the specified radius and returns a map of
/// all closest elements and their values.
pub fn nearest_neighbors(&mut self, val: &T, radius: f64) -> HashMap<T, f64> {
pub fn nearest_neighbors(&self, val: &T, radius: f64) -> HashMap<T, f64> {
// First iterate over all nodes to identify all neighbors
let mut neighbors = HashMap::new();
for (_i, check) in self.nodes.iter().enumerate() {
Expand Down
20 changes: 11 additions & 9 deletions tests/rrt_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

use ordered_float::OrderedFloat;
use rand::Rng;
use rustplanning::planning::rrt::{rrt, rrtstar};
use rustplanning::planning::rrt::rrt;
use rustplanning::tree::Distance;
use std::fmt;

Expand Down Expand Up @@ -92,13 +92,16 @@ fn run_rrt_test(use_rrtstar: bool) {
let is_valid = |_: &Point2D, end: &Point2D| end.distance(&obstacle) > 1.0;
let success = |p: &Point2D| p.distance(&goal) < success_distance;

let result;
if use_rrtstar {
result = rrtstar(&start, sample, extend, is_valid, success, 0.2, 10000);
}
else {
result = rrt(&start, sample, extend, is_valid, success, 10000);
}
let result = rrt(
&start,
sample,
extend,
is_valid,
success,
use_rrtstar,
0.2,
10000,
);

assert!(result.is_ok(), "Expected Ok result, got Err");

Expand All @@ -112,7 +115,6 @@ fn run_rrt_test(use_rrtstar: bool) {
end.distance(&goal) < success_distance,
"Path should end near the goal"
);

}

#[test]
Expand Down

0 comments on commit 382c088

Please sign in to comment.