reth_trie_parallel/
proof.rs

1use crate::{
2    metrics::ParallelTrieMetrics, root::ParallelStateRootError, stats::ParallelTrieTracker,
3    StorageRootTargets,
4};
5use alloy_primitives::{
6    map::{B256Map, HashMap},
7    B256,
8};
9use alloy_rlp::{BufMut, Encodable};
10use itertools::Itertools;
11use reth_execution_errors::StorageRootError;
12use reth_provider::{
13    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, ProviderError,
14    StateCommitmentProvider,
15};
16use reth_storage_errors::db::DatabaseError;
17use reth_trie::{
18    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
19    node_iter::{TrieElement, TrieNodeIter},
20    prefix_set::{PrefixSetMut, TriePrefixSetsMut},
21    proof::StorageProof,
22    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
23    updates::TrieUpdatesSorted,
24    walker::TrieWalker,
25    HashBuilder, HashedPostStateSorted, MultiProof, MultiProofTargets, Nibbles, StorageMultiProof,
26    TRIE_ACCOUNT_RLP_MAX_SIZE,
27};
28use reth_trie_common::proof::ProofRetainer;
29use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
30use std::{sync::Arc, time::Instant};
31use tokio::runtime::Handle;
32use tracing::{debug, trace};
33
/// Parallel proof calculator.
///
/// This can collect proof for many targets in parallel, spawning a task for each hashed address
/// that has proof targets.
#[derive(Debug)]
pub struct ParallelProof<Factory> {
    /// Consistent view of the database.
    view: ConsistentDbView<Factory>,
    /// The sorted collection of cached in-memory intermediate trie nodes that
    /// can be reused for computation.
    pub nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    pub state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    pub prefix_sets: Arc<TriePrefixSetsMut>,
    /// Flag indicating whether to include branch node masks in the proof.
    collect_branch_node_masks: bool,
    /// Handle to spawn blocking tasks to fetch data.
    executor: Handle,
    /// Parallel trie metrics, labeled with `type = "proof"` on construction.
    #[cfg(feature = "metrics")]
    metrics: ParallelTrieMetrics,
}
58
59impl<Factory> ParallelProof<Factory> {
60    /// Create new state proof generator.
61    pub fn new(
62        view: ConsistentDbView<Factory>,
63        nodes_sorted: Arc<TrieUpdatesSorted>,
64        state_sorted: Arc<HashedPostStateSorted>,
65        prefix_sets: Arc<TriePrefixSetsMut>,
66        executor: Handle,
67    ) -> Self {
68        Self {
69            view,
70            nodes_sorted,
71            state_sorted,
72            prefix_sets,
73            collect_branch_node_masks: false,
74            executor,
75            #[cfg(feature = "metrics")]
76            metrics: ParallelTrieMetrics::new_with_labels(&[("type", "proof")]),
77        }
78    }
79
80    /// Set the flag indicating whether to include branch node masks in the proof.
81    pub const fn with_branch_node_masks(mut self, branch_node_masks: bool) -> Self {
82        self.collect_branch_node_masks = branch_node_masks;
83        self
84    }
85}
86
impl<Factory> ParallelProof<Factory>
where
    Factory:
        DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + Clone + 'static,
{
    /// Generate a state multiproof according to specified targets.
    ///
    /// Storage multiproofs for each changed account are computed concurrently on blocking
    /// tasks spawned via the executor [`Handle`], while the account trie is walked on the
    /// current thread, lazily receiving each storage proof result when its leaf is reached.
    ///
    /// Returns a [`MultiProof`] with the account subtree, per-target storage multiproofs,
    /// and (when enabled via `with_branch_node_masks`) branch node hash/tree masks.
    ///
    /// # Errors
    ///
    /// Returns [`ParallelStateRootError`] if a database read, provider acquisition, or any
    /// storage proof computation fails.
    pub fn multiproof(
        self,
        targets: MultiProofTargets,
    ) -> Result<MultiProof, ParallelStateRootError> {
        let mut tracker = ParallelTrieTracker::default();

        // Extend prefix sets with targets
        let mut prefix_sets = (*self.prefix_sets).clone();
        prefix_sets.extend(TriePrefixSetsMut {
            account_prefix_set: PrefixSetMut::from(targets.keys().copied().map(Nibbles::unpack)),
            // Only accounts with at least one targeted slot contribute a storage prefix set.
            storage_prefix_sets: targets
                .iter()
                .filter(|&(_hashed_address, slots)| !slots.is_empty())
                .map(|(hashed_address, slots)| {
                    (*hashed_address, PrefixSetMut::from(slots.iter().map(Nibbles::unpack)))
                })
                .collect(),
            destroyed_accounts: Default::default(),
        });
        let prefix_sets = prefix_sets.freeze();

        // Pair each changed account (from the account prefix set) with its storage prefix set
        // so that a storage proof task can be scheduled per account.
        let storage_root_targets = StorageRootTargets::new(
            prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
            prefix_sets.storage_prefix_sets.clone(),
        );
        let storage_root_targets_len = storage_root_targets.len();

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            "Starting parallel proof generation"
        );

        // Pre-calculate storage roots for accounts which were changed.
        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

        // stores the receiver for the storage proof outcome for the hashed addresses
        // this way we can lazily await the outcome when we iterate over the map
        let mut storage_proofs =
            B256Map::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

        for (hashed_address, prefix_set) in
            storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
        {
            // Clone the shared inputs so the spawned task owns everything it needs.
            let view = self.view.clone();
            let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();
            let trie_nodes_sorted = self.nodes_sorted.clone();
            let hashed_state_sorted = self.state_sorted.clone();
            let collect_masks = self.collect_branch_node_masks;

            // Bounded single-slot channel: each task sends exactly one result.
            let (tx, rx) = std::sync::mpsc::sync_channel(1);

            // spawn the task as blocking and send the result through the channel
            self.executor.spawn_blocking(move || {
                debug!(
                    target: "trie::parallel_proof",
                    ?hashed_address,
                    "Starting proof calculation"
                );

                let task_start = Instant::now();
                let result = (|| -> Result<_, ParallelStateRootError> {
                    let provider_start = Instant::now();
                    let provider_ro = view.provider_ro()?;
                    trace!(
                        target: "trie::parallel_proof",
                        ?hashed_address,
                        provider_time = ?provider_start.elapsed(),
                        "Got provider"
                    );

                    // Layer the in-memory trie/state overlays on top of the database cursors.
                    let cursor_start = Instant::now();
                    let trie_cursor_factory = InMemoryTrieCursorFactory::new(
                        DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
                        &trie_nodes_sorted,
                    );
                    let hashed_cursor_factory = HashedPostStateCursorFactory::new(
                        DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
                        &hashed_state_sorted,
                    );
                    trace!(
                        target: "trie::parallel_proof",
                        ?hashed_address,
                        cursor_time = ?cursor_start.elapsed(),
                        "Created cursors"
                    );

                    let target_slots_len = target_slots.len();
                    let proof_start = Instant::now();
                    let proof_result = StorageProof::new_hashed(
                        trie_cursor_factory,
                        hashed_cursor_factory,
                        hashed_address,
                    )
                    .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
                    .with_branch_node_masks(collect_masks)
                    .storage_multiproof(target_slots)
                    .map_err(|e| ParallelStateRootError::Other(e.to_string()));

                    trace!(
                        target: "trie::parallel_proof",
                        ?hashed_address,
                        prefix_set = ?prefix_set.len(),
                        target_slots = ?target_slots_len,
                        proof_time = ?proof_start.elapsed(),
                        "Completed proof calculation"
                    );

                    proof_result
                })();

                // We can have the receiver dropped before we send, because we still calculate
                // storage proofs for deleted accounts, but do not actually walk over them in
                // `account_node_iter` below.
                if let Err(e) = tx.send(result) {
                    debug!(
                        target: "trie::parallel_proof",
                        ?hashed_address,
                        error = ?e,
                        task_time = ?task_start.elapsed(),
                        "Failed to send proof result"
                    );
                }
            });

            // store the receiver for that result with the hashed address so we can await this in
            // place when we iterate over the trie
            storage_proofs.insert(hashed_address, rx);
        }

        // Cursor factories for the account trie walk on this thread, layering the same
        // in-memory overlays over a fresh read-only provider.
        let provider_ro = self.view.provider_ro()?;
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
            &self.nodes_sorted,
        );
        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
            &self.state_sorted,
        );

        // Create the walker.
        let walker = TrieWalker::new(
            trie_cursor_factory.account_trie_cursor().map_err(ProviderError::Database)?,
            prefix_sets.account_prefix_set,
        )
        .with_deletions_retained(true);

        // Create a hash builder to rebuild the root node since it is not available in the database.
        let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
        let mut hash_builder = HashBuilder::default()
            .with_proof_retainer(retainer)
            .with_updates(self.collect_branch_node_masks);

        // Initialize all storage multiproofs as empty.
        // Storage multiproofs for non empty tries will be overwritten if necessary.
        let mut storages: B256Map<_> =
            targets.keys().map(|key| (*key, StorageMultiProof::empty())).collect();
        let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
        let mut account_node_iter = TrieNodeIter::new(
            walker,
            hashed_cursor_factory.hashed_account_cursor().map_err(ProviderError::Database)?,
        );
        while let Some(account_node) =
            account_node_iter.try_next().map_err(ProviderError::Database)?
        {
            match account_node {
                TrieElement::Branch(node) => {
                    hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
                }
                TrieElement::Leaf(hashed_address, account) => {
                    let storage_multiproof = match storage_proofs.remove(&hashed_address) {
                        // Block until the spawned task for this account delivers its result.
                        // A closed channel means the task failed before sending.
                        Some(rx) => rx.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                DatabaseError::Other(format!(
                                    "channel closed for {hashed_address}"
                                )),
                            ))
                        })??,
                        // Since we do not store all intermediate nodes in the database, there might
                        // be a possibility of re-adding a non-modified leaf to the hash builder.
                        None => {
                            tracker.inc_missed_leaves();
                            StorageProof::new_hashed(
                                trie_cursor_factory.clone(),
                                hashed_cursor_factory.clone(),
                                hashed_address,
                            )
                            .with_prefix_set_mut(Default::default())
                            .storage_multiproof(
                                targets.get(&hashed_address).cloned().unwrap_or_default(),
                            )
                            .map_err(|e| {
                                ParallelStateRootError::StorageRoot(StorageRootError::Database(
                                    DatabaseError::Other(e.to_string()),
                                ))
                            })?
                        }
                    };

                    // Encode account
                    account_rlp.clear();
                    let account = account.into_trie_account(storage_multiproof.root);
                    account.encode(&mut account_rlp as &mut dyn BufMut);

                    hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);

                    // We might be adding leaves that are not necessarily our proof targets.
                    if targets.contains_key(&hashed_address) {
                        storages.insert(hashed_address, storage_multiproof);
                    }
                }
            }
        }
        // Finalize the builder so retained proof nodes (and branch node updates, if enabled)
        // are populated; the computed root value itself is not needed here.
        let _ = hash_builder.root();

        let stats = tracker.finish();
        #[cfg(feature = "metrics")]
        self.metrics.record(stats);

        let account_subtree = hash_builder.take_proof_nodes();
        // Split the collected branch node updates into separate hash-mask and tree-mask maps.
        let (branch_node_hash_masks, branch_node_tree_masks) = if self.collect_branch_node_masks {
            let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
            (
                updated_branch_nodes
                    .iter()
                    .map(|(path, node)| (path.clone(), node.hash_mask))
                    .collect(),
                updated_branch_nodes
                    .into_iter()
                    .map(|(path, node)| (path, node.tree_mask))
                    .collect(),
            )
        } else {
            (HashMap::default(), HashMap::default())
        };

        debug!(
            target: "trie::parallel_proof",
            total_targets = storage_root_targets_len,
            duration = ?stats.duration(),
            branches_added = stats.branches_added(),
            leaves_added = stats.leaves_added(),
            missed_leaves = stats.missed_leaves(),
            precomputed_storage_roots = stats.precomputed_storage_roots(),
            "Calculated proof"
        );

        Ok(MultiProof { account_subtree, branch_node_hash_masks, branch_node_tree_masks, storages })
    }
}
343
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{
        keccak256,
        map::{B256Set, DefaultHashBuilder},
        Address, U256,
    };
    use rand::Rng;
    use reth_primitives::{Account, StorageEntry};
    use reth_provider::{test_utils::create_test_provider_factory, HashingWriter};
    use reth_trie::proof::Proof;
    use tokio::runtime::Runtime;

    /// Checks that [`ParallelProof`] produces the same multiproof as the sequential
    /// [`Proof`] implementation for a randomly generated state.
    #[test]
    fn random_parallel_proof() {
        let factory = create_test_provider_factory();
        let consistent_view = ConsistentDbView::new(factory.clone(), None);

        // Generate 100 random accounts; ~70% of them also get 100 random storage slots.
        let mut rng = rand::thread_rng();
        let state = (0..100)
            .map(|_| {
                let address = Address::random();
                let account =
                    Account { balance: U256::from(rng.gen::<u64>()), ..Default::default() };
                let mut storage = HashMap::<B256, U256, DefaultHashBuilder>::default();
                let has_storage = rng.gen_bool(0.7);
                if has_storage {
                    for _ in 0..100 {
                        storage.insert(
                            B256::from(U256::from(rng.gen::<u64>())),
                            U256::from(rng.gen::<u64>()),
                        );
                    }
                }
                (address, (account, storage))
            })
            .collect::<HashMap<_, _, DefaultHashBuilder>>();

        // Persist the generated accounts and storage to the hashed tables.
        {
            let provider_rw = factory.provider_rw().unwrap();
            provider_rw
                .insert_account_for_hashing(
                    state.iter().map(|(address, (account, _))| (*address, Some(*account))),
                )
                .unwrap();
            provider_rw
                .insert_storage_for_hashing(state.iter().map(|(address, (_, storage))| {
                    (
                        *address,
                        storage
                            .iter()
                            .map(|(slot, value)| StorageEntry { key: *slot, value: *value }),
                    )
                }))
                .unwrap();
            provider_rw.commit().unwrap();
        }

        // Build proof targets from the first 10 accounts, taking up to 5 slots each.
        let mut targets = MultiProofTargets::default();
        for (address, (_, storage)) in state.iter().take(10) {
            let hashed_address = keccak256(*address);
            let mut target_slots = B256Set::default();

            for (slot, _) in storage.iter().take(5) {
                target_slots.insert(*slot);
            }

            if !target_slots.is_empty() {
                targets.insert(hashed_address, target_slots);
            }
        }

        let provider_rw = factory.provider_rw().unwrap();
        let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
        let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

        let rt = Runtime::new().unwrap();

        // The parallel proof (with empty overlays) must equal the sequential reference proof.
        assert_eq!(
            ParallelProof::new(
                consistent_view,
                Default::default(),
                Default::default(),
                Default::default(),
                rt.handle().clone()
            )
            .multiproof(targets.clone())
            .unwrap(),
            Proof::new(trie_cursor_factory, hashed_cursor_factory).multiproof(targets).unwrap()
        );
    }
}