Skip to content

Commit

Permalink
Iterator only needs a node reference (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
gsserge authored Nov 11, 2024
1 parent 745eb63 commit 1a46f59
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 29 deletions.
3 changes: 1 addition & 2 deletions src/art.rs
Original file line number Diff line number Diff line change
Expand Up @@ -890,8 +890,7 @@ impl<P: KeyTrait, V: Clone> Node<P, V> {
///
/// Returns a boxed iterator that yields tuples containing keys and references to child nodes.
///
#[allow(dead_code)]
pub(crate) fn iter(&self) -> Box<dyn DoubleEndedIterator<Item = (u8, &Arc<Self>)> + '_> {
pub(crate) fn iter(&self) -> Box<dyn DoubleEndedIterator<Item = &Arc<Self>> + '_> {
match &self.node_type {
NodeType::Node4(n) => Box::new(n.iter()),
NodeType::Node16(n) => Box::new(n.iter()),
Expand Down
22 changes: 11 additions & 11 deletions src/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::art::{Node, NodeType, QueryType};
use crate::node::{LeafValue, TwigNode};
use crate::KeyTrait;

type NodeIterator<'a, P, V> = Box<dyn DoubleEndedIterator<Item = (u8, &'a Arc<Node<P, V>>)> + 'a>;
type NodeIterator<'a, P, V> = Box<dyn DoubleEndedIterator<Item = &'a Arc<Node<P, V>>> + 'a>;

/// An iterator over the nodes in the Trie.
struct NodeIter<'a, P: KeyTrait, V: Clone> {
Expand All @@ -22,7 +22,7 @@ impl<'a, P: KeyTrait, V: Clone> NodeIter<'a, P, V> {
///
fn new<I>(iter: I) -> Self
where
I: DoubleEndedIterator<Item = (u8, &'a Arc<Node<P, V>>)> + 'a,
I: DoubleEndedIterator<Item = &'a Arc<Node<P, V>>> + 'a,
{
Self {
node: Box::new(iter),
Expand All @@ -31,7 +31,7 @@ impl<'a, P: KeyTrait, V: Clone> NodeIter<'a, P, V> {
}

impl<'a, P: KeyTrait, V: Clone> Iterator for NodeIter<'a, P, V> {
type Item = (u8, &'a Arc<Node<P, V>>);
type Item = &'a Arc<Node<P, V>>;

fn next(&mut self) -> Option<Self::Item> {
self.node.next()
Expand Down Expand Up @@ -114,7 +114,7 @@ impl<'a, P: KeyTrait + 'a, V: Clone> Iterator for Iter<'a, P, V> {
self.forward.iters.pop();
}
Some(other) => {
if let NodeType::Twig(twig) = &other.1.node_type {
if let NodeType::Twig(twig) = &other.node_type {
if self.forward.is_versioned {
for leaf in twig.iter() {
self.forward.leafs.push_back(Leaf(&twig.key, leaf));
Expand All @@ -124,7 +124,7 @@ impl<'a, P: KeyTrait + 'a, V: Clone> Iterator for Iter<'a, P, V> {
}
break;
} else {
self.forward.iters.push(NodeIter::new(other.1.iter()));
self.forward.iters.push(NodeIter::new(other.iter()));
}
}
}
Expand Down Expand Up @@ -161,7 +161,7 @@ impl<'a, P: KeyTrait + 'a, V: Clone> DoubleEndedIterator for Iter<'a, P, V> {
self.backward.iters.pop();
}
Some(other) => {
if let NodeType::Twig(twig) = &other.1.node_type {
if let NodeType::Twig(twig) = &other.node_type {
if self.backward.is_versioned {
for leaf in twig.iter() {
self.backward.leafs.push(Leaf(&twig.key, leaf));
Expand All @@ -171,7 +171,7 @@ impl<'a, P: KeyTrait + 'a, V: Clone> DoubleEndedIterator for Iter<'a, P, V> {
}
break;
} else {
self.backward.iters.push(NodeIter::new(other.1.iter()));
self.backward.iters.push(NodeIter::new(other.iter()));
}
}
}
Expand Down Expand Up @@ -498,7 +498,7 @@ impl<'a, K: 'a + KeyTrait, V: Clone, R: RangeBounds<K>> Iterator for Range<'a, K
fn next(&mut self) -> Option<Self::Item> {
while let Some(node) = self.forward.iters.last_mut() {
if let Some(other) = node.next() {
if let NodeType::Twig(twig) = &other.1.node_type {
if let NodeType::Twig(twig) = &other.node_type {
if self.range.contains(&twig.key) {
self.handle_twig(twig);
break;
Expand All @@ -510,7 +510,7 @@ impl<'a, K: 'a + KeyTrait, V: Clone, R: RangeBounds<K>> Iterator for Range<'a, K
&mut self.prefix,
&mut self.prefix_lengths,
&self.range,
other.1,
other,
&mut self.forward.iters,
);
}
Expand Down Expand Up @@ -589,7 +589,7 @@ where
let e = node.next();
match e {
Some(other) => {
if let NodeType::Twig(twig) = &other.1.node_type {
if let NodeType::Twig(twig) = &other.node_type {
if range.contains(&twig.key) {
// Iterate through leaves of the twig
if let Some(leaf) = twig.get_leaf_by_query(query_type) {
Expand All @@ -609,7 +609,7 @@ where
&mut prefix,
&mut prefix_lengths,
&range,
other.1,
other,
&mut forward.iters,
);
}
Expand Down
27 changes: 11 additions & 16 deletions src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,11 @@ impl<P: KeyTrait, N, const WIDTH: usize> FlatNode<P, N, WIDTH> {
}

#[inline]
pub(crate) fn iter(&self) -> impl DoubleEndedIterator<Item = (u8, &Arc<N>)> {
self.keys
pub(crate) fn iter(&self) -> impl DoubleEndedIterator<Item = &Arc<N>> {
self.children
.iter()
.zip(self.children.iter())
.take(self.num_children as usize)
.filter_map(|(&k, c)| c.as_ref().map(|child| (k, child)))
.filter_map(|child| child.as_ref())
}
}

Expand Down Expand Up @@ -442,12 +441,11 @@ impl<P: KeyTrait, N> Node48<P, N> {
n256
}

pub(crate) fn iter(&self) -> impl DoubleEndedIterator<Item = (u8, &Arc<N>)> {
pub(crate) fn iter(&self) -> impl DoubleEndedIterator<Item = &Arc<N>> {
self.keys
.iter()
.enumerate()
.filter(|(_, x)| **x != u8::MAX)
.map(move |(key, pos)| (key as u8, self.children[*pos as usize].as_ref().unwrap()))
.filter(|key| **key != u8::MAX)
.map(move |pos| self.children[*pos as usize].as_ref().unwrap())
}
}

Expand Down Expand Up @@ -558,11 +556,8 @@ impl<P: KeyTrait, N> Node256<P, N> {
self.num_children += new_insert as usize;
}

pub(crate) fn iter(&self) -> impl DoubleEndedIterator<Item = (u8, &Arc<N>)> {
self.children
.iter()
.enumerate()
.filter_map(|(key, node)| node.as_ref().map(|x| (key as u8, x)))
pub(crate) fn iter(&self) -> impl DoubleEndedIterator<Item = &Arc<N>> {
self.children.iter().filter_map(|node| node.as_ref())
}
}

Expand Down Expand Up @@ -885,7 +880,7 @@ mod tests {
}

for child in node.iter() {
assert_eq!(Arc::strong_count(child.1), 1);
assert_eq!(Arc::strong_count(child), 1);
}

// Create and test Node48
Expand All @@ -895,7 +890,7 @@ mod tests {
}

for child in n48.iter() {
assert_eq!(Arc::strong_count(child.1), 1);
assert_eq!(Arc::strong_count(child), 1);
}

// Create and test Node256
Expand All @@ -905,7 +900,7 @@ mod tests {
}

for child in n256.iter() {
assert_eq!(Arc::strong_count(child.1), 1);
assert_eq!(Arc::strong_count(child), 1);
}
}

Expand Down

0 comments on commit 1a46f59

Please sign in to comment.