reth_trie/
walker.rs

1use crate::{
2    prefix_set::PrefixSet,
3    trie_cursor::{subnode::SubNodePosition, CursorSubNode, TrieCursor},
4    BranchNodeCompact, Nibbles,
5};
6use alloy_primitives::{map::HashSet, B256};
7use reth_storage_errors::db::DatabaseError;
8use tracing::{instrument, trace};
9
10#[cfg(feature = "metrics")]
11use crate::metrics::WalkerMetrics;
12
13/// `TrieWalker` is a structure that enables traversal of a Merkle trie.
14/// It allows moving through the trie in a depth-first manner, skipping certain branches
15/// if they have not changed.
16#[derive(Debug)]
17pub struct TrieWalker<C> {
18    /// A mutable reference to a trie cursor instance used for navigating the trie.
19    pub cursor: C,
20    /// A vector containing the trie nodes that have been visited.
21    pub stack: Vec<CursorSubNode>,
22    /// A flag indicating whether the current node can be skipped when traversing the trie. This
23    /// is determined by whether the current key's prefix is included in the prefix set and if the
24    /// hash flag is set.
25    pub can_skip_current_node: bool,
26    /// A `PrefixSet` representing the changes to be applied to the trie.
27    pub changes: PrefixSet,
28    /// The retained trie node keys that need to be removed.
29    removed_keys: Option<HashSet<Nibbles>>,
30    #[cfg(feature = "metrics")]
31    /// Walker metrics.
32    metrics: WalkerMetrics,
33}
34
35impl<C> TrieWalker<C> {
36    /// Constructs a new `TrieWalker` from existing stack and a cursor.
37    pub fn from_stack(cursor: C, stack: Vec<CursorSubNode>, changes: PrefixSet) -> Self {
38        let mut this = Self {
39            cursor,
40            changes,
41            stack,
42            can_skip_current_node: false,
43            removed_keys: None,
44            #[cfg(feature = "metrics")]
45            metrics: WalkerMetrics::default(),
46        };
47        this.update_skip_node();
48        this
49    }
50
51    /// Sets the flag whether the trie updates should be stored.
52    pub fn with_deletions_retained(mut self, retained: bool) -> Self {
53        if retained {
54            self.removed_keys = Some(HashSet::default());
55        }
56        self
57    }
58
59    /// Split the walker into stack and trie updates.
60    pub fn split(mut self) -> (Vec<CursorSubNode>, HashSet<Nibbles>) {
61        let keys = self.take_removed_keys();
62        (self.stack, keys)
63    }
64
65    /// Take removed keys from the walker.
66    pub fn take_removed_keys(&mut self) -> HashSet<Nibbles> {
67        self.removed_keys.take().unwrap_or_default()
68    }
69
70    /// Prints the current stack of trie nodes.
71    pub fn print_stack(&self) {
72        println!("====================== STACK ======================");
73        for node in &self.stack {
74            println!("{node:?}");
75        }
76        println!("====================== END STACK ======================\n");
77    }
78
79    /// The current length of the removed keys.
80    pub fn removed_keys_len(&self) -> usize {
81        self.removed_keys.as_ref().map_or(0, |u| u.len())
82    }
83
84    /// Returns the current key in the trie.
85    pub fn key(&self) -> Option<&Nibbles> {
86        self.stack.last().map(|n| n.full_key())
87    }
88
89    /// Returns the current hash in the trie if any.
90    pub fn hash(&self) -> Option<B256> {
91        self.stack.last().and_then(|n| n.hash())
92    }
93
94    /// Indicates whether the children of the current node are present in the trie.
95    pub fn children_are_in_trie(&self) -> bool {
96        self.stack.last().is_some_and(|n| n.tree_flag())
97    }
98
99    /// Returns the next unprocessed key in the trie along with its raw [`Nibbles`] representation.
100    #[instrument(level = "trace", skip(self), ret)]
101    pub fn next_unprocessed_key(&self) -> Option<(B256, Nibbles)> {
102        self.key()
103            .and_then(
104                |key| if self.can_skip_current_node { key.increment() } else { Some(key.clone()) },
105            )
106            .map(|key| {
107                let mut packed = key.pack();
108                packed.resize(32, 0);
109                (B256::from_slice(packed.as_slice()), key)
110            })
111    }
112
113    /// Updates the skip node flag based on the walker's current state.
114    fn update_skip_node(&mut self) {
115        let old = self.can_skip_current_node;
116        self.can_skip_current_node = self
117            .stack
118            .last()
119            .is_some_and(|node| !self.changes.contains(node.full_key()) && node.hash_flag());
120        trace!(
121            target: "trie::walker",
122            old,
123            new = self.can_skip_current_node,
124            last = ?self.stack.last(),
125            "updated skip node flag"
126        );
127    }
128}
129
130impl<C: TrieCursor> TrieWalker<C> {
131    /// Constructs a new `TrieWalker`, setting up the initial state of the stack and cursor.
132    pub fn new(cursor: C, changes: PrefixSet) -> Self {
133        // Initialize the walker with a single empty stack element.
134        let mut this = Self {
135            cursor,
136            changes,
137            stack: vec![CursorSubNode::default()],
138            can_skip_current_node: false,
139            removed_keys: None,
140            #[cfg(feature = "metrics")]
141            metrics: WalkerMetrics::default(),
142        };
143
144        // Set up the root node of the trie in the stack, if it exists.
145        if let Some((key, value)) = this.node(true).unwrap() {
146            this.stack[0] = CursorSubNode::new(key, Some(value));
147        }
148
149        // Update the skip state for the root node.
150        this.update_skip_node();
151        this
152    }
153
154    /// Advances the walker to the next trie node and updates the skip node flag.
155    /// The new key can then be obtained via `key()`.
156    ///
157    /// # Returns
158    ///
159    /// * `Result<(), Error>` - Unit on success or an error.
160    pub fn advance(&mut self) -> Result<(), DatabaseError> {
161        if let Some(last) = self.stack.last() {
162            if !self.can_skip_current_node && self.children_are_in_trie() {
163                trace!(
164                    target: "trie::walker",
165                    position = ?last.position(),
166                    "cannot skip current node and children are in the trie"
167                );
168                // If we can't skip the current node and the children are in the trie,
169                // either consume the next node or move to the next sibling.
170                match last.position() {
171                    SubNodePosition::ParentBranch => self.move_to_next_sibling(true)?,
172                    SubNodePosition::Child(_) => self.consume_node()?,
173                }
174            } else {
175                trace!(target: "trie::walker", "can skip current node");
176                // If we can skip the current node, move to the next sibling.
177                self.move_to_next_sibling(false)?;
178            }
179
180            // Update the skip node flag based on the new position in the trie.
181            self.update_skip_node();
182        }
183
184        Ok(())
185    }
186
187    /// Retrieves the current root node from the DB, seeking either the exact node or the next one.
188    fn node(&mut self, exact: bool) -> Result<Option<(Nibbles, BranchNodeCompact)>, DatabaseError> {
189        let key = self.key().expect("key must exist").clone();
190        let entry = if exact { self.cursor.seek_exact(key)? } else { self.cursor.seek(key)? };
191
192        if let Some((_, node)) = &entry {
193            assert!(!node.state_mask.is_empty());
194        }
195
196        Ok(entry)
197    }
198
199    /// Consumes the next node in the trie, updating the stack.
200    #[instrument(level = "trace", skip(self), ret)]
201    fn consume_node(&mut self) -> Result<(), DatabaseError> {
202        let Some((key, node)) = self.node(false)? else {
203            // If no next node is found, clear the stack.
204            self.stack.clear();
205            return Ok(())
206        };
207
208        // Overwrite the root node's first nibble
209        // We need to sync the stack with the trie structure when consuming a new node. This is
210        // necessary for proper traversal and accurately representing the trie in the stack.
211        if !key.is_empty() && !self.stack.is_empty() {
212            self.stack[0].set_nibble(key[0]);
213        }
214
215        // The current tree mask might have been set incorrectly.
216        // Sanity check that the newly retrieved trie node key is the child of the last item
217        // on the stack. If not, advance to the next sibling instead of adding the node to the
218        // stack.
219        if let Some(subnode) = self.stack.last() {
220            if !key.starts_with(subnode.full_key()) {
221                #[cfg(feature = "metrics")]
222                self.metrics.inc_out_of_order_subnode(1);
223                self.move_to_next_sibling(false)?;
224                return Ok(())
225            }
226        }
227
228        // Create a new CursorSubNode and push it to the stack.
229        let subnode = CursorSubNode::new(key, Some(node));
230        let position = subnode.position();
231        self.stack.push(subnode);
232        self.update_skip_node();
233
234        // Delete the current node if it's included in the prefix set or it doesn't contain the root
235        // hash.
236        if !self.can_skip_current_node || position.is_child() {
237            if let Some((keys, key)) = self.removed_keys.as_mut().zip(self.cursor.current()?) {
238                keys.insert(key);
239            }
240        }
241
242        Ok(())
243    }
244
245    /// Moves to the next sibling node in the trie, updating the stack.
246    #[instrument(level = "trace", skip(self), ret)]
247    fn move_to_next_sibling(
248        &mut self,
249        allow_root_to_child_nibble: bool,
250    ) -> Result<(), DatabaseError> {
251        let Some(subnode) = self.stack.last_mut() else { return Ok(()) };
252
253        // Check if the walker needs to backtrack to the previous level in the trie during its
254        // traversal.
255        if subnode.position().is_last_child() ||
256            (subnode.position().is_parent() && !allow_root_to_child_nibble)
257        {
258            self.stack.pop();
259            self.move_to_next_sibling(false)?;
260            return Ok(())
261        }
262
263        subnode.inc_nibble();
264
265        if subnode.node.is_none() {
266            return self.consume_node()
267        }
268
269        // Find the next sibling with state.
270        loop {
271            let position = subnode.position();
272            if subnode.state_flag() {
273                trace!(target: "trie::walker", ?position, "found next sibling with state");
274                return Ok(())
275            }
276            if position.is_last_child() {
277                trace!(target: "trie::walker", ?position, "checked all siblings");
278                break
279            }
280            subnode.inc_nibble();
281        }
282
283        // Pop the current node and move to the next sibling.
284        self.stack.pop();
285        self.move_to_next_sibling(false)?;
286
287        Ok(())
288    }
289}