reth_trie_parallel/
proof_task.rs

1//! A Task that manages sending proof requests to a number of tasks that have longer-running
2//! database transactions.
3//!
//! The [`ProofTaskManager`] ensures that at most a fixed number of proof tasks execute
//! concurrently, and is responsible for managing the pool of database transactions, which are
//! created on demand up to that limit and reused across tasks.
7//!
8//! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
9//! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
10
11use crate::root::ParallelStateRootError;
12use alloy_primitives::{map::B256Set, B256};
13use reth_db_api::transaction::DbTx;
14use reth_execution_errors::SparseTrieError;
15use reth_provider::{
16    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
17    ProviderResult,
18};
19use reth_trie::{
20    hashed_cursor::HashedPostStateCursorFactory,
21    prefix_set::TriePrefixSetsMut,
22    proof::{ProofTrieNodeProviderFactory, StorageProof},
23    trie_cursor::InMemoryTrieCursorFactory,
24    updates::TrieUpdatesSorted,
25    DecodedStorageMultiProof, HashedPostStateSorted, Nibbles,
26};
27use reth_trie_common::{
28    added_removed_keys::MultiAddedRemovedKeys,
29    prefix_set::{PrefixSet, PrefixSetMut},
30};
31use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
32use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
33use std::{
34    collections::VecDeque,
35    sync::{
36        atomic::{AtomicUsize, Ordering},
37        mpsc::{channel, Receiver, SendError, Sender},
38        Arc,
39    },
40    time::Instant,
41};
42use tokio::runtime::Handle;
43use tracing::debug;
44
45#[cfg(feature = "metrics")]
46use crate::proof_task_metrics::ProofTaskMetrics;
47
/// Outcome of a storage proof task: the decoded storage multiproof, or the error raised while
/// computing or decoding it.
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
/// Outcome of a blinded node request: the revealed node, if any, for the requested path, or a
/// sparse trie error.
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
50
/// A task that manages sending multiproof requests to a number of tasks that have longer-running
/// database transactions.
///
/// Requests arrive as [`ProofTaskMessage`]s over an internal channel and are queued until a
/// [`ProofTaskTx`] is available. Each [`ProofTaskTx`] is handed to a blocking task on the executor
/// and is sent back over the same channel once that task has finished.
#[derive(Debug)]
pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
    /// Max number of database transactions to create
    max_concurrency: usize,
    /// Number of database transactions created so far. Never exceeds `max_concurrency`.
    total_transactions: usize,
    /// Consistent view provider used for creating transactions on-demand
    view: ConsistentDbView<Factory>,
    /// Proof task context shared across all proof tasks
    task_ctx: ProofTaskCtx,
    /// Proof tasks pending execution, in the order in which they will be spawned.
    pending_tasks: VecDeque<ProofTaskKind>,
    /// The underlying handle from which to spawn proof tasks
    executor: Handle,
    /// The idle proof task transactions, containing owned cursor factories that are reused for
    /// proof calculation. A transaction is removed from this list while a task is using it.
    proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
    /// A receiver for new proof tasks, returned transactions and termination requests.
    proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
    /// The sending half of `proof_task_rx`. Cloned into every handle and into every spawned task,
    /// which uses it for sending back its transaction.
    tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
    /// The number of active handles.
    ///
    /// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
    /// [`ProofTaskManagerHandle::drop`].
    active_handles: Arc<AtomicUsize>,
    /// Metrics tracking blinded node fetches.
    #[cfg(feature = "metrics")]
    metrics: ProofTaskMetrics,
}
83
84impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
85    /// Creates a new [`ProofTaskManager`] with the given max concurrency, creating that number of
86    /// cursor factories.
87    ///
88    /// Returns an error if the consistent view provider fails to create a read-only transaction.
89    pub fn new(
90        executor: Handle,
91        view: ConsistentDbView<Factory>,
92        task_ctx: ProofTaskCtx,
93        max_concurrency: usize,
94    ) -> Self {
95        let (tx_sender, proof_task_rx) = channel();
96        Self {
97            max_concurrency,
98            total_transactions: 0,
99            view,
100            task_ctx,
101            pending_tasks: VecDeque::new(),
102            executor,
103            proof_task_txs: Vec::new(),
104            proof_task_rx,
105            tx_sender,
106            active_handles: Arc::new(AtomicUsize::new(0)),
107            #[cfg(feature = "metrics")]
108            metrics: ProofTaskMetrics::default(),
109        }
110    }
111
112    /// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
113    pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
114        ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
115    }
116}
117
118impl<Factory> ProofTaskManager<Factory>
119where
120    Factory: DatabaseProviderFactory<Provider: BlockReader> + 'static,
121{
122    /// Inserts the task into the pending tasks queue.
123    pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
124        self.pending_tasks.push_back(task);
125    }
126
127    /// Gets either the next available transaction, or creates a new one if all are in use and the
128    /// total number of transactions created is less than the max concurrency.
129    pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
130        if let Some(proof_task_tx) = self.proof_task_txs.pop() {
131            return Ok(Some(proof_task_tx));
132        }
133
134        // if we can create a new tx within our concurrency limits, create one on-demand
135        if self.total_transactions < self.max_concurrency {
136            let provider_ro = self.view.provider_ro()?;
137            let tx = provider_ro.into_tx();
138            self.total_transactions += 1;
139            return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone(), self.total_transactions)));
140        }
141
142        Ok(None)
143    }
144
145    /// Spawns the next queued proof task on the executor with the given input, if there are any
146    /// transactions available.
147    ///
148    /// This will return an error if a transaction must be created on-demand and the consistent view
149    /// provider fails.
150    pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
151        let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };
152
153        let Some(proof_task_tx) = self.get_or_create_tx()? else {
154            // if there are no txs available, requeue the proof task
155            self.pending_tasks.push_front(task);
156            return Ok(())
157        };
158
159        let tx_sender = self.tx_sender.clone();
160        self.executor.spawn_blocking(move || match task {
161            ProofTaskKind::StorageProof(input, sender) => {
162                proof_task_tx.storage_proof(input, sender, tx_sender);
163            }
164            ProofTaskKind::BlindedAccountNode(path, sender) => {
165                proof_task_tx.blinded_account_node(path, sender, tx_sender);
166            }
167            ProofTaskKind::BlindedStorageNode(account, path, sender) => {
168                proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
169            }
170        });
171
172        Ok(())
173    }
174
175    /// Loops, managing the proof tasks, and sending new tasks to the executor.
176    pub fn run(mut self) -> ProviderResult<()> {
177        loop {
178            match self.proof_task_rx.recv() {
179                Ok(message) => match message {
180                    ProofTaskMessage::QueueTask(task) => {
181                        // Track metrics for blinded node requests
182                        #[cfg(feature = "metrics")]
183                        match &task {
184                            ProofTaskKind::BlindedAccountNode(_, _) => {
185                                self.metrics.account_nodes += 1;
186                            }
187                            ProofTaskKind::BlindedStorageNode(_, _, _) => {
188                                self.metrics.storage_nodes += 1;
189                            }
190                            _ => {}
191                        }
192                        // queue the task
193                        self.queue_proof_task(task)
194                    }
195                    ProofTaskMessage::Transaction(tx) => {
196                        // return the transaction to the pool
197                        self.proof_task_txs.push(tx);
198                    }
199                    ProofTaskMessage::Terminate => {
200                        // Record metrics before terminating
201                        #[cfg(feature = "metrics")]
202                        self.metrics.record();
203                        return Ok(())
204                    }
205                },
206                // All senders are disconnected, so we can terminate
207                // However this should never happen, as this struct stores a sender
208                Err(_) => return Ok(()),
209            };
210
211            // try spawning the next task
212            self.try_spawn_next()?;
213        }
214    }
215}
216
/// A database transaction bundled with the shared [`ProofTaskCtx`], from which the cursor
/// factories for each proof calculation are built.
#[derive(Debug)]
pub struct ProofTaskTx<Tx> {
    /// The tx that is reused for proof calculations.
    tx: Tx,

    /// Trie updates, prefix sets, and state updates
    task_ctx: ProofTaskCtx,

    /// Identifier for the tx within the context of a single [`ProofTaskManager`], used only for
    /// tracing.
    id: usize,
}
230
231impl<Tx> ProofTaskTx<Tx> {
232    /// Initializes a [`ProofTaskTx`] using the given transaction and a [`ProofTaskCtx`]. The id is
233    /// used only for tracing.
234    const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
235        Self { tx, task_ctx, id }
236    }
237}
238
239impl<Tx> ProofTaskTx<Tx>
240where
241    Tx: DbTx,
242{
243    fn create_factories(
244        &self,
245    ) -> (
246        InMemoryTrieCursorFactory<'_, DatabaseTrieCursorFactory<'_, Tx>>,
247        HashedPostStateCursorFactory<'_, DatabaseHashedCursorFactory<'_, Tx>>,
248    ) {
249        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
250            DatabaseTrieCursorFactory::new(&self.tx),
251            &self.task_ctx.nodes_sorted,
252        );
253
254        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
255            DatabaseHashedCursorFactory::new(&self.tx),
256            &self.task_ctx.state_sorted,
257        );
258
259        (trie_cursor_factory, hashed_cursor_factory)
260    }
261
262    /// Calculates a storage proof for the given hashed address, and desired prefix set.
263    fn storage_proof(
264        self,
265        input: StorageProofInput,
266        result_sender: Sender<StorageProofResult>,
267        tx_sender: Sender<ProofTaskMessage<Tx>>,
268    ) {
269        debug!(
270            target: "trie::proof_task",
271            hashed_address=?input.hashed_address,
272            "Starting storage proof task calculation"
273        );
274
275        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
276        let multi_added_removed_keys = input
277            .multi_added_removed_keys
278            .unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
279        let added_removed_keys = multi_added_removed_keys.get_storage(&input.hashed_address);
280
281        let span = tracing::trace_span!(
282            target: "trie::proof_task",
283            "Storage proof calculation",
284            hashed_address=?input.hashed_address,
285            // Add a unique id because we often have parallel storage proof calculations for the
286            // same hashed address, and we want to differentiate them during trace analysis.
287            span_id=self.id,
288        );
289        let span_guard = span.enter();
290
291        let target_slots_len = input.target_slots.len();
292        let proof_start = Instant::now();
293
294        let raw_proof_result = StorageProof::new_hashed(
295            trie_cursor_factory,
296            hashed_cursor_factory,
297            input.hashed_address,
298        )
299        .with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().copied()))
300        .with_branch_node_masks(input.with_branch_node_masks)
301        .with_added_removed_keys(added_removed_keys)
302        .storage_multiproof(input.target_slots)
303        .map_err(|e| ParallelStateRootError::Other(e.to_string()));
304
305        drop(span_guard);
306
307        let decoded_result = raw_proof_result.and_then(|raw_proof| {
308            raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
309                ParallelStateRootError::Other(format!(
310                    "Failed to decode storage proof for {}: {}",
311                    input.hashed_address, e
312                ))
313            })
314        });
315
316        debug!(
317            target: "trie::proof_task",
318            hashed_address=?input.hashed_address,
319            prefix_set = ?input.prefix_set.len(),
320            target_slots = ?target_slots_len,
321            proof_time = ?proof_start.elapsed(),
322            "Completed storage proof task calculation"
323        );
324
325        // send the result back
326        if let Err(error) = result_sender.send(decoded_result) {
327            debug!(
328                target: "trie::proof_task",
329                hashed_address = ?input.hashed_address,
330                ?error,
331                task_time = ?proof_start.elapsed(),
332                "Storage proof receiver is dropped, discarding the result"
333            );
334        }
335
336        // send the tx back
337        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
338    }
339
340    /// Retrieves blinded account node by path.
341    fn blinded_account_node(
342        self,
343        path: Nibbles,
344        result_sender: Sender<TrieNodeProviderResult>,
345        tx_sender: Sender<ProofTaskMessage<Tx>>,
346    ) {
347        debug!(
348            target: "trie::proof_task",
349            ?path,
350            "Starting blinded account node retrieval"
351        );
352
353        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
354
355        let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
356            trie_cursor_factory,
357            hashed_cursor_factory,
358            self.task_ctx.prefix_sets.clone(),
359        );
360
361        let start = Instant::now();
362        let result = blinded_provider_factory.account_node_provider().trie_node(&path);
363        debug!(
364            target: "trie::proof_task",
365            ?path,
366            elapsed = ?start.elapsed(),
367            "Completed blinded account node retrieval"
368        );
369
370        if let Err(error) = result_sender.send(result) {
371            tracing::error!(
372                target: "trie::proof_task",
373                ?path,
374                ?error,
375                "Failed to send blinded account node result"
376            );
377        }
378
379        // send the tx back
380        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
381    }
382
383    /// Retrieves blinded storage node of the given account by path.
384    fn blinded_storage_node(
385        self,
386        account: B256,
387        path: Nibbles,
388        result_sender: Sender<TrieNodeProviderResult>,
389        tx_sender: Sender<ProofTaskMessage<Tx>>,
390    ) {
391        debug!(
392            target: "trie::proof_task",
393            ?account,
394            ?path,
395            "Starting blinded storage node retrieval"
396        );
397
398        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
399
400        let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
401            trie_cursor_factory,
402            hashed_cursor_factory,
403            self.task_ctx.prefix_sets.clone(),
404        );
405
406        let start = Instant::now();
407        let result = blinded_provider_factory.storage_node_provider(account).trie_node(&path);
408        debug!(
409            target: "trie::proof_task",
410            ?account,
411            ?path,
412            elapsed = ?start.elapsed(),
413            "Completed blinded storage node retrieval"
414        );
415
416        if let Err(error) = result_sender.send(result) {
417            tracing::error!(
418                target: "trie::proof_task",
419                ?account,
420                ?path,
421                ?error,
422                "Failed to send blinded storage node result"
423            );
424        }
425
426        // send the tx back
427        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
428    }
429}
430
/// This represents an input for a storage proof.
///
/// Instances are created with [`StorageProofInput::new`] and passed to the manager inside
/// [`ProofTaskKind::StorageProof`].
#[derive(Debug)]
pub struct StorageProofInput {
    /// The hashed address for which the proof is calculated.
    hashed_address: B256,
    /// The prefix set for the proof calculation.
    prefix_set: PrefixSet,
    /// The target slots for the proof calculation.
    target_slots: B256Set,
    /// Whether or not to collect branch node masks
    with_branch_node_masks: bool,
    /// Provided by the user to give the necessary context to retain extra proofs. When `None`, an
    /// empty [`MultiAddedRemovedKeys`] is used for the calculation.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}
445
446impl StorageProofInput {
447    /// Creates a new [`StorageProofInput`] with the given hashed address, prefix set, and target
448    /// slots.
449    pub const fn new(
450        hashed_address: B256,
451        prefix_set: PrefixSet,
452        target_slots: B256Set,
453        with_branch_node_masks: bool,
454        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
455    ) -> Self {
456        Self {
457            hashed_address,
458            prefix_set,
459            target_slots,
460            with_branch_node_masks,
461            multi_added_removed_keys,
462        }
463    }
464}
465
/// Data used for initializing cursor factories that is shared across all storage proof instances.
///
/// All fields are reference-counted, so cloning the context does not copy the underlying data.
#[derive(Debug, Clone)]
pub struct ProofTaskCtx {
    /// The sorted collection of cached in-memory intermediate trie nodes that can be reused for
    /// computation.
    nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    prefix_sets: Arc<TriePrefixSetsMut>,
}
479
480impl ProofTaskCtx {
481    /// Creates a new [`ProofTaskCtx`] with the given sorted nodes and state.
482    pub const fn new(
483        nodes_sorted: Arc<TrieUpdatesSorted>,
484        state_sorted: Arc<HashedPostStateSorted>,
485        prefix_sets: Arc<TriePrefixSetsMut>,
486    ) -> Self {
487        Self { nodes_sorted, state_sorted, prefix_sets }
488    }
489}
490
/// Message used to communicate with [`ProofTaskManager`].
///
/// All variants travel over the same channel, which the manager drains in
/// [`ProofTaskManager::run`].
#[derive(Debug)]
pub enum ProofTaskMessage<Tx> {
    /// A request to queue a proof task.
    QueueTask(ProofTaskKind),
    /// A returned database transaction, sent by a proof task once it has finished so that the
    /// transaction can be reused.
    Transaction(ProofTaskTx<Tx>),
    /// A request to terminate the proof task manager.
    Terminate,
}
501
/// Proof task kind.
///
/// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
/// specifies the type of proof task to be executed.
///
/// Every variant carries the sender through which the result of the task is reported.
#[derive(Debug)]
pub enum ProofTaskKind {
    /// A storage proof request: the proof input and the sender for the decoded proof.
    StorageProof(StorageProofInput, Sender<StorageProofResult>),
    /// A blinded account node request: the path of the node and the sender for the result.
    BlindedAccountNode(Nibbles, Sender<TrieNodeProviderResult>),
    /// A blinded storage node request: the account, the path of the node and the sender for the
    /// result.
    BlindedStorageNode(B256, Nibbles, Sender<TrieNodeProviderResult>),
}
515
/// A handle that wraps a single proof task sender that sends a terminate message on `Drop` if the
/// number of active handles went to zero.
///
/// Cloning a handle registers the clone in the same counter.
#[derive(Debug)]
pub struct ProofTaskManagerHandle<Tx> {
    /// The sender for the proof task manager.
    sender: Sender<ProofTaskMessage<Tx>>,
    /// The number of active handles, shared between all handles of the same manager.
    active_handles: Arc<AtomicUsize>,
}
525
526impl<Tx> ProofTaskManagerHandle<Tx> {
527    /// Creates a new [`ProofTaskManagerHandle`] with the given sender.
528    pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
529        active_handles.fetch_add(1, Ordering::SeqCst);
530        Self { sender, active_handles }
531    }
532
533    /// Queues a task to the proof task manager.
534    pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
535        self.sender.send(ProofTaskMessage::QueueTask(task))
536    }
537
538    /// Terminates the proof task manager.
539    pub fn terminate(&self) {
540        let _ = self.sender.send(ProofTaskMessage::Terminate);
541    }
542}
543
544impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
545    fn clone(&self) -> Self {
546        Self::new(self.sender.clone(), self.active_handles.clone())
547    }
548}
549
550impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
551    fn drop(&mut self) {
552        // Decrement the number of active handles and terminate the manager if it was the last
553        // handle.
554        if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
555            self.terminate();
556        }
557    }
558}
559
560impl<Tx: DbTx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx> {
561    type AccountNodeProvider = ProofTaskTrieNodeProvider<Tx>;
562    type StorageNodeProvider = ProofTaskTrieNodeProvider<Tx>;
563
564    fn account_node_provider(&self) -> Self::AccountNodeProvider {
565        ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
566    }
567
568    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
569        ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
570    }
571}
572
/// Trie node provider for retrieving trie nodes by path.
///
/// A lookup is forwarded to the proof task manager as a [`ProofTaskKind`] request through the
/// contained sender.
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider<Tx> {
    /// Blinded account trie node provider.
    AccountNode {
        /// Sender to the proof task.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
    /// Blinded storage trie node provider.
    StorageNode {
        /// Target account.
        account: B256,
        /// Sender to the proof task.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
}
589
590impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
591    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
592        let (tx, rx) = channel();
593        match self {
594            Self::AccountNode { sender } => {
595                let _ = sender.send(ProofTaskMessage::QueueTask(
596                    ProofTaskKind::BlindedAccountNode(*path, tx),
597                ));
598            }
599            Self::StorageNode { sender, account } => {
600                let _ = sender.send(ProofTaskMessage::QueueTask(
601                    ProofTaskKind::BlindedStorageNode(*account, *path, tx),
602                ));
603            }
604        }
605
606        rx.recv().unwrap()
607    }
608}