Skip to content

Commit

Permalink
Doing some borrowing cleanups...
Browse files Browse the repository at this point in the history
  • Loading branch information
eholum committed Mar 16, 2024
1 parent 5abe2b7 commit 3be76ce
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 74 deletions.
46 changes: 35 additions & 11 deletions src/planning/rrt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,34 @@ use crate::tree::Distance;
use crate::tree::Tree;
use std::hash::Hash;

/// Attempts to randomly extend the tree in an arbitrary direction.
/// Return the new point and the nearest neighbor, if available.
/// Otherwise return None.
fn extend_tree<T, FS, FE, FV>(
tree: &Tree<T>,
sample: &mut FS,
extend: &mut FE,
is_valid: &mut FV,
) -> Option<(T, T)>
where
T: Eq + Copy + Hash + Distance,
FS: FnMut() -> T,
FE: FnMut(&T, &T) -> T,
FV: FnMut(&T) -> bool,
{
// Sample the grab the nearest point, and extend in that direction
let s = sample();
let nearest = tree.nearest_neighbor(&s);
let new_point = extend(&nearest, &s);

// If it is an invalid point try again
if !is_valid(&new_point) {
return None;
}

Some((new_point, nearest.clone()))
}

/// Basic RRT implementation.
///
/// Will attempt to compute a path using the RRT algorithm given the specified start pose
Expand All @@ -38,7 +66,7 @@ use std::hash::Hash;
/// - `success`: Determines whether or not we have reached the goal
/// - `max_iterations`: Maximum number of random samples to attempt before the search fails
///
/// /// # Returns
/// # Returns
/// Returns a `Result` containing either:
/// - `Ok(Vec<T>)`: A vector of points of type `T` representing the path from the start to a point
/// satisfying the `success` condition, if such a path is found within the given number
Expand Down Expand Up @@ -68,18 +96,14 @@ where
let mut tree = Tree::new(start.clone());

for _ in 0..max_iterations {
// Sample the grab the nearest point, and extend in that direction
let s = sample();
let nearest = tree.nearest(&s);
let new_point = extend(&nearest, &s);

// If it is an invalid point try again
if !is_valid(&new_point) {
continue;
}
let (new_point, nearest) = match extend_tree(&tree, &mut sample, &mut extend, &mut is_valid)
{
Some((new_point, nearest)) => (new_point, nearest),
None => continue,
};

// Otherwise it's valid so add it to the tree
if let Ok(_) = tree.add_child(nearest, new_point) {
if let Ok(_) = tree.add_child(&nearest, new_point) {
// We're good
} else {
// Then the child wasn't added for some reason so just try again
Expand Down
124 changes: 62 additions & 62 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,26 @@ use std::hash::Hash;
///
/// Must be used with [Tree] since children are referenced by index in the [Tree]'s node vector.
#[derive(Debug)]
struct TreeNode<T> {
struct Node<T> {
// The value of this node.
value: T,

// Maintains a list of pointers to the children's location in the parent's vector
children: Vec<usize>,

// Location of the nodes parent, if available
parent: Option<usize>,

// Maintains a list of pointers to the children's location in the parent's vector.
// Using a vector to maintain order for tree traversals.
children: Vec<usize>,
}

impl<T> TreeNode<T> {
fn new(val: T, parent: Option<usize>) -> Self {
TreeNode {
value: val,
impl<T> Node<T> {
fn new(value: T, parent: Option<usize>) -> Self {
Node {
value: value,
parent: parent,
children: Vec::new(),
}
}

// Add a child to this node's children. Uniqueness must be enforced by the caller.
fn add_child(&mut self, child: usize) {
self.children.push(child);
}
}

/// Define a distance trait for tree node values.
Expand All @@ -62,15 +59,15 @@ pub trait Distance {
/// DFS Iterator for a [Tree]
pub struct DepthFirstIterator<'a, T>
where
T: 'a + Eq + Copy + Distance + Hash,
T: 'a + Eq + Clone + Distance + Hash,
{
tree: &'a Tree<T>,
stack: Vec<usize>,
}

impl<'a, T> DepthFirstIterator<'a, T>
where
T: Eq + Copy + Distance + Hash,
T: Eq + Clone + Distance + Hash,
{
fn new(tree: &'a Tree<T>) -> Self {
let mut stack = Vec::new();
Expand All @@ -84,7 +81,7 @@ where

impl<'a, T> Iterator for DepthFirstIterator<'a, T>
where
T: Eq + Copy + Distance + Hash,
T: Eq + Clone + Distance + Hash,
{
type Item = &'a T;

Expand All @@ -102,38 +99,40 @@ where

/// Basic tree for use in search algorithms.
///
/// Provides functions for creating, growing, and finding the nearest neighbor to `T`.
/// Provides functions for creating, growing, finding the nearest neighbors to `T`,
/// and rewiring the based on cost are provided.
/// Node values must be unique.
///
/// TODO: Make this a KD Tree?
/// TODO: Is a hashmap dumb?
#[derive(Debug)]
pub struct Tree<T>
where
T: Eq + Copy + Distance + Hash,
T: Eq + Clone + Distance + Hash,
{
// Store nodes in a vector to support easy iteration and growth of the tree.
// The root will always be at idx 0.
nodes: Vec<TreeNode<T>>,
// Detailed node data for the tree.
nodes: Vec<Node<T>>,

// Support constant time lookup of nodes by value
// Support constant time lookup of nodes data with a value - node index map.
nodes_map: HashMap<T, usize>,
}

impl<T: Eq + Copy + Distance + Hash> Tree<T> {

impl<T: Eq + Clone + Distance + Hash> Tree<T> {
/// Construct a new tree with the specified value as the root node.
///
/// The node will take ownership of the provided value.
pub fn new(val: T) -> Self {
let root_node = TreeNode::new(val, None);
let mut nodes = Vec::new();
let mut nodes_map = HashMap::new();

// Construct root node and add it to storage
let root_node = Node::new(val.clone(), None);
nodes.push(root_node);
nodes_map.insert(val, 0);

Tree {
nodes: nodes,
nodes_map: nodes_map,
nodes,
nodes_map,
}
}

Expand All @@ -143,7 +142,7 @@ impl<T: Eq + Copy + Distance + Hash> Tree<T> {
///
/// If the parent is not found in the tree.
/// If the child is already in the tree.
pub fn add_child(&mut self, parent: T, child: T) -> Result<(), String> {
pub fn add_child(&mut self, parent: &T, child: T) -> Result<(), String> {
// Cannot duplicate children
if self.nodes_map.contains_key(&child) {
return Err("The child is already in the tree".to_string());
Expand All @@ -152,9 +151,9 @@ impl<T: Eq + Copy + Distance + Hash> Tree<T> {
if let Some(&parent_idx) = self.nodes_map.get(&parent) {
// Append the child node to the nodes vector and note the location in the map.
let child_idx = self.nodes.len();
self.nodes.push(TreeNode::new(child, Some(parent_idx)));
self.nodes.push(Node::new(child.clone(), Some(parent_idx)));
self.nodes_map.insert(child, child_idx);
self.nodes[parent_idx].add_child(child_idx);
self.nodes[parent_idx].children.push(child_idx);
} else {
return Err("The parent cannot be found in the tree".to_string());
}
Expand All @@ -168,8 +167,9 @@ impl<T: Eq + Copy + Distance + Hash> Tree<T> {
}

/// Returns the closest element to the specified value
pub fn nearest(&self, val: &T) -> T {
self.nodes
pub fn nearest_neighbor(&self, val: &T) -> &T {
&self
.nodes
.iter()
.min_by(|a, b| {
let da = val.distance(&a.value);
Expand Down Expand Up @@ -197,12 +197,12 @@ impl<T: Eq + Copy + Distance + Hash> Tree<T> {
}

// Build the path from end to beginning
let mut path: Vec<T> = Vec::new();
let mut path = Vec::new();

// Loop until you get to the root
let mut cur_idx = Some(self.nodes_map[&end]);
while let Some(idx) = cur_idx {
path.push(self.nodes[idx].value);
path.push(self.nodes[idx].value.clone());
cur_idx = self.nodes[idx].parent;
}

Expand All @@ -216,7 +216,7 @@ impl<T: Eq + Copy + Distance + Hash> Tree<T> {
// Unit tests
//

// Needed for distancing points on a plane
// Needed for distancing points on a line
impl Distance for i32 {
fn distance(&self, other: &Self) -> f64 {
(self - other).abs().into()
Expand All @@ -235,48 +235,48 @@ mod tests {
assert_eq!(tree.nodes[0].value, 1);

// Add a child and make sure everything is ok
assert!(tree.add_child(1, 2).is_ok());
assert!(tree.add_child(&1, 2).is_ok());
assert_eq!(tree.size(), 2);

// Make the tree bigger
assert!(tree.add_child(1, 3).is_ok());
assert!(tree.add_child(2, 4).is_ok());
assert!(tree.add_child(&1, 3).is_ok());
assert!(tree.add_child(&2, 4).is_ok());
assert_eq!(tree.size(), 4);

// Add an existing child and everything is not ok
assert!(tree.add_child(1, 2).is_err());
assert!(tree.add_child(&1, 2).is_err());

// Add to a nonexistent parent and everything is not ok
assert!(tree.add_child(3, 2).is_err());
assert!(tree.add_child(&3, 2).is_err());
}

#[test]
fn test_tree_get_nearest() {
// Construct tree with many nodes
let mut tree: Tree<i32> = Tree::new(1);

assert!(tree.add_child(1, 2).is_ok());
assert!(tree.add_child(1, 3).is_ok());
assert!(tree.add_child(2, 4).is_ok());
assert!(tree.add_child(2, 5).is_ok());
assert!(tree.add_child(2, 6).is_ok());
assert!(tree.add_child(&1, 2).is_ok());
assert!(tree.add_child(&1, 3).is_ok());
assert!(tree.add_child(&2, 4).is_ok());
assert!(tree.add_child(&2, 5).is_ok());
assert!(tree.add_child(&2, 6).is_ok());

// Make assertions
assert_eq!(tree.nearest(&7), 6);
assert_eq!(tree.nearest(&-1), 1);
assert_eq!(tree.nearest(&3), 3);
assert_eq!(tree.nearest_neighbor(&7), &6);
assert_eq!(tree.nearest_neighbor(&-1), &1);
assert_eq!(tree.nearest_neighbor(&3), &3);
}

#[test]
fn test_tree_dfs() {
// Construct tree with many nodes
let mut tree: Tree<i32> = Tree::new(1);

assert!(tree.add_child(1, 2).is_ok());
assert!(tree.add_child(1, 3).is_ok());
assert!(tree.add_child(2, 4).is_ok());
assert!(tree.add_child(2, 5).is_ok());
assert!(tree.add_child(3, 6).is_ok());
assert!(tree.add_child(&1, 2).is_ok());
assert!(tree.add_child(&1, 3).is_ok());
assert!(tree.add_child(&2, 4).is_ok());
assert!(tree.add_child(&2, 5).is_ok());
assert!(tree.add_child(&3, 6).is_ok());

// Expected order
let expected_dfs_order = vec![1, 2, 4, 5, 3, 6];
Expand All @@ -291,20 +291,20 @@ mod tests {
// Construct tree with many nodes
let mut tree: Tree<i32> = Tree::new(1);

assert!(tree.add_child(1, 2).is_ok());
assert!(tree.add_child(1, 3).is_ok());
assert!(tree.add_child(2, 4).is_ok());
assert!(tree.add_child(2, 5).is_ok());
assert!(tree.add_child(3, 7).is_ok());
assert!(tree.add_child(5, 6).is_ok());
assert!(tree.add_child(&1, 2).is_ok());
assert!(tree.add_child(&1, 3).is_ok());
assert!(tree.add_child(&2, 4).is_ok());
assert!(tree.add_child(&2, 5).is_ok());
assert!(tree.add_child(&3, 7).is_ok());
assert!(tree.add_child(&5, 6).is_ok());

// Verify
// Verify expected paths to different nodes
let ep1 = vec![1, 2, 5, 6];
let cp1: Vec<i32> = tree.path(&6).unwrap();
let cp1 = tree.path(&6).unwrap();
assert_eq!(cp1, ep1);

let ep2 = vec![1, 3, 7];
let cp2: Vec<i32> = tree.path(&7).unwrap();
let cp2 = tree.path(&7).unwrap();
assert_eq!(cp2, ep2);

// Invalid node
Expand Down
2 changes: 1 addition & 1 deletion tests/rrt_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ fn test_rrt() {
let success = |p: &Point2D| p.distance(&goal) < success_distance;

let result = rrt(
&start, sample, extend, is_valid, success, 10000, // Max iterations
&start, sample, extend, is_valid, success, 10000,
);

assert!(result.is_ok(), "Expected Ok result, got Err");
Expand Down

0 comments on commit 3be76ce

Please sign in to comment.