reth_trie_parallel/
proof_task.rs

1//! A Task that manages sending proof requests to a number of tasks that have longer-running
2//! database transactions.
3//!
//! The [`ProofTaskManager`] caps the number of proof tasks executing at once and manages a pool
//! of database transactions. New transactions are opened on demand until that cap is reached,
//! and are reused once a finished task returns them.
7//!
//! Each [`ProofTaskTx`] owns a single database transaction. For every task it runs, it builds an
//! [`InMemoryTrieCursorFactory`] and a [`HashedPostStateCursorFactory`] on top of that transaction.
10
11use crate::root::ParallelStateRootError;
12use alloy_primitives::{map::B256Set, B256};
13use reth_db_api::transaction::DbTx;
14use reth_execution_errors::SparseTrieError;
15use reth_provider::{
16    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
17    ProviderResult, StateCommitmentProvider,
18};
19use reth_trie::{
20    hashed_cursor::HashedPostStateCursorFactory,
21    prefix_set::TriePrefixSetsMut,
22    proof::{ProofBlindedProviderFactory, StorageProof},
23    trie_cursor::InMemoryTrieCursorFactory,
24    updates::TrieUpdatesSorted,
25    HashedPostStateSorted, Nibbles, StorageMultiProof,
26};
27use reth_trie_common::prefix_set::{PrefixSet, PrefixSetMut};
28use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
29use reth_trie_sparse::blinded::{BlindedProvider, BlindedProviderFactory, RevealedNode};
30use std::{
31    collections::VecDeque,
32    sync::{
33        atomic::{AtomicUsize, Ordering},
34        mpsc::{channel, Receiver, SendError, Sender},
35        Arc,
36    },
37    time::Instant,
38};
39use tokio::runtime::Handle;
40use tracing::debug;
41
42type StorageProofResult = Result<StorageMultiProof, ParallelStateRootError>;
43type BlindedNodeResult = Result<Option<RevealedNode>, SparseTrieError>;
44
45/// A task that manages sending multiproof requests to a number of tasks that have longer-running
46/// database transactions
47#[derive(Debug)]
48pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
49    /// Max number of database transactions to create
50    max_concurrency: usize,
51    /// Number of database transactions created
52    total_transactions: usize,
53    /// Consistent view provider used for creating transactions on-demand
54    view: ConsistentDbView<Factory>,
55    /// Proof task context shared across all proof tasks
56    task_ctx: ProofTaskCtx,
57    /// Proof tasks pending execution
58    pending_tasks: VecDeque<ProofTaskKind>,
59    /// The underlying handle from which to spawn proof tasks
60    executor: Handle,
61    /// The proof task transactions, containing owned cursor factories that are reused for proof
62    /// calculation.
63    proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
64    /// A receiver for new proof tasks.
65    proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
66    /// A sender for sending back transactions.
67    tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
68    /// The number of active handles.
69    ///
70    /// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
71    /// [`ProofTaskManagerHandle::drop`].
72    active_handles: Arc<AtomicUsize>,
73}
74
75impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
76    /// Creates a new [`ProofTaskManager`] with the given max concurrency, creating that number of
77    /// cursor factories.
78    ///
79    /// Returns an error if the consistent view provider fails to create a read-only transaction.
80    pub fn new(
81        executor: Handle,
82        view: ConsistentDbView<Factory>,
83        task_ctx: ProofTaskCtx,
84        max_concurrency: usize,
85    ) -> Self {
86        let (tx_sender, proof_task_rx) = channel();
87        Self {
88            max_concurrency,
89            total_transactions: 0,
90            view,
91            task_ctx,
92            pending_tasks: VecDeque::new(),
93            executor,
94            proof_task_txs: Vec::new(),
95            proof_task_rx,
96            tx_sender,
97            active_handles: Arc::new(AtomicUsize::new(0)),
98        }
99    }
100
101    /// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
102    pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
103        ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
104    }
105}
106
107impl<Factory> ProofTaskManager<Factory>
108where
109    Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + 'static,
110{
111    /// Inserts the task into the pending tasks queue.
112    pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
113        self.pending_tasks.push_back(task);
114    }
115
116    /// Gets either the next available transaction, or creates a new one if all are in use and the
117    /// total number of transactions created is less than the max concurrency.
118    pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
119        if let Some(proof_task_tx) = self.proof_task_txs.pop() {
120            return Ok(Some(proof_task_tx));
121        }
122
123        // if we can create a new tx within our concurrency limits, create one on-demand
124        if self.total_transactions < self.max_concurrency {
125            let provider_ro = self.view.provider_ro()?;
126            let tx = provider_ro.into_tx();
127            self.total_transactions += 1;
128            return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone())));
129        }
130
131        Ok(None)
132    }
133
134    /// Spawns the next queued proof task on the executor with the given input, if there are any
135    /// transactions available.
136    ///
137    /// This will return an error if a transaction must be created on-demand and the consistent view
138    /// provider fails.
139    pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
140        let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };
141
142        let Some(proof_task_tx) = self.get_or_create_tx()? else {
143            // if there are no txs available, requeue the proof task
144            self.pending_tasks.push_front(task);
145            return Ok(())
146        };
147
148        let tx_sender = self.tx_sender.clone();
149        self.executor.spawn_blocking(move || match task {
150            ProofTaskKind::StorageProof(input, sender) => {
151                proof_task_tx.storage_proof(input, sender, tx_sender);
152            }
153            ProofTaskKind::BlindedAccountNode(path, sender) => {
154                proof_task_tx.blinded_account_node(path, sender, tx_sender);
155            }
156            ProofTaskKind::BlindedStorageNode(account, path, sender) => {
157                proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
158            }
159        });
160
161        Ok(())
162    }
163
164    /// Loops, managing the proof tasks, and sending new tasks to the executor.
165    pub fn run(mut self) -> ProviderResult<()> {
166        loop {
167            match self.proof_task_rx.recv() {
168                Ok(message) => match message {
169                    ProofTaskMessage::QueueTask(task) => {
170                        // queue the task
171                        self.queue_proof_task(task)
172                    }
173                    ProofTaskMessage::Transaction(tx) => {
174                        // return the transaction to the pool
175                        self.proof_task_txs.push(tx);
176                    }
177                    ProofTaskMessage::Terminate => return Ok(()),
178                },
179                // All senders are disconnected, so we can terminate
180                // However this should never happen, as this struct stores a sender
181                Err(_) => return Ok(()),
182            };
183
184            // try spawning the next task
185            self.try_spawn_next()?;
186        }
187    }
188}
189
190/// This contains all information shared between all storage proof instances.
191#[derive(Debug)]
192pub struct ProofTaskTx<Tx> {
193    /// The tx that is reused for proof calculations.
194    tx: Tx,
195
196    /// Trie updates, prefix sets, and state updates
197    task_ctx: ProofTaskCtx,
198}
199
200impl<Tx> ProofTaskTx<Tx> {
201    /// Initializes a [`ProofTaskTx`] using the given transaction anda[`ProofTaskCtx`].
202    const fn new(tx: Tx, task_ctx: ProofTaskCtx) -> Self {
203        Self { tx, task_ctx }
204    }
205}
206
207impl<Tx> ProofTaskTx<Tx>
208where
209    Tx: DbTx,
210{
211    fn create_factories(
212        &self,
213    ) -> (
214        InMemoryTrieCursorFactory<'_, DatabaseTrieCursorFactory<'_, Tx>>,
215        HashedPostStateCursorFactory<'_, DatabaseHashedCursorFactory<'_, Tx>>,
216    ) {
217        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
218            DatabaseTrieCursorFactory::new(&self.tx),
219            &self.task_ctx.nodes_sorted,
220        );
221
222        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
223            DatabaseHashedCursorFactory::new(&self.tx),
224            &self.task_ctx.state_sorted,
225        );
226
227        (trie_cursor_factory, hashed_cursor_factory)
228    }
229
230    /// Calculates a storage proof for the given hashed address, and desired prefix set.
231    fn storage_proof(
232        self,
233        input: StorageProofInput,
234        result_sender: Sender<StorageProofResult>,
235        tx_sender: Sender<ProofTaskMessage<Tx>>,
236    ) {
237        debug!(
238            target: "trie::proof_task",
239            hashed_address=?input.hashed_address,
240            "Starting storage proof task calculation"
241        );
242
243        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
244
245        let target_slots_len = input.target_slots.len();
246        let proof_start = Instant::now();
247        let result = StorageProof::new_hashed(
248            trie_cursor_factory,
249            hashed_cursor_factory,
250            input.hashed_address,
251        )
252        .with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().cloned()))
253        .with_branch_node_masks(input.with_branch_node_masks)
254        .storage_multiproof(input.target_slots)
255        .map_err(|e| ParallelStateRootError::Other(e.to_string()));
256
257        debug!(
258            target: "trie::proof_task",
259            hashed_address=?input.hashed_address,
260            prefix_set = ?input.prefix_set.len(),
261            target_slots = ?target_slots_len,
262            proof_time = ?proof_start.elapsed(),
263            "Completed storage proof task calculation"
264        );
265
266        // send the result back
267        if let Err(error) = result_sender.send(result) {
268            debug!(
269                target: "trie::proof_task",
270                hashed_address = ?input.hashed_address,
271                ?error,
272                task_time = ?proof_start.elapsed(),
273                "Failed to send proof result"
274            );
275        }
276
277        // send the tx back
278        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
279    }
280
281    /// Retrieves blinded account node by path.
282    fn blinded_account_node(
283        self,
284        path: Nibbles,
285        result_sender: Sender<BlindedNodeResult>,
286        tx_sender: Sender<ProofTaskMessage<Tx>>,
287    ) {
288        debug!(
289            target: "trie::proof_task",
290            ?path,
291            "Starting blinded account node retrieval"
292        );
293
294        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
295
296        let blinded_provider_factory = ProofBlindedProviderFactory::new(
297            trie_cursor_factory,
298            hashed_cursor_factory,
299            self.task_ctx.prefix_sets.clone(),
300        );
301
302        let start = Instant::now();
303        let result = blinded_provider_factory.account_node_provider().blinded_node(&path);
304        debug!(
305            target: "trie::proof_task",
306            ?path,
307            elapsed = ?start.elapsed(),
308            "Completed blinded account node retrieval"
309        );
310
311        if let Err(error) = result_sender.send(result) {
312            tracing::error!(
313                target: "trie::proof_task",
314                ?path,
315                ?error,
316                "Failed to send blinded account node result"
317            );
318        }
319
320        // send the tx back
321        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
322    }
323
324    /// Retrieves blinded storage node of the given account by path.
325    fn blinded_storage_node(
326        self,
327        account: B256,
328        path: Nibbles,
329        result_sender: Sender<BlindedNodeResult>,
330        tx_sender: Sender<ProofTaskMessage<Tx>>,
331    ) {
332        debug!(
333            target: "trie::proof_task",
334            ?account,
335            ?path,
336            "Starting blinded storage node retrieval"
337        );
338
339        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
340
341        let blinded_provider_factory = ProofBlindedProviderFactory::new(
342            trie_cursor_factory,
343            hashed_cursor_factory,
344            self.task_ctx.prefix_sets.clone(),
345        );
346
347        let start = Instant::now();
348        let result = blinded_provider_factory.storage_node_provider(account).blinded_node(&path);
349        debug!(
350            target: "trie::proof_task",
351            ?account,
352            ?path,
353            elapsed = ?start.elapsed(),
354            "Completed blinded storage node retrieval"
355        );
356
357        if let Err(error) = result_sender.send(result) {
358            tracing::error!(
359                target: "trie::proof_task",
360                ?account,
361                ?path,
362                ?error,
363                "Failed to send blinded storage node result"
364            );
365        }
366
367        // send the tx back
368        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
369    }
370}
371
372/// This represents an input for a storage proof.
373#[derive(Debug)]
374pub struct StorageProofInput {
375    /// The hashed address for which the proof is calculated.
376    hashed_address: B256,
377    /// The prefix set for the proof calculation.
378    prefix_set: PrefixSet,
379    /// The target slots for the proof calculation.
380    target_slots: B256Set,
381    /// Whether or not to collect branch node masks
382    with_branch_node_masks: bool,
383}
384
385impl StorageProofInput {
386    /// Creates a new [`StorageProofInput`] with the given hashed address, prefix set, and target
387    /// slots.
388    pub const fn new(
389        hashed_address: B256,
390        prefix_set: PrefixSet,
391        target_slots: B256Set,
392        with_branch_node_masks: bool,
393    ) -> Self {
394        Self { hashed_address, prefix_set, target_slots, with_branch_node_masks }
395    }
396}
397
398/// Data used for initializing cursor factories that is shared across all storage proof instances.
399#[derive(Debug, Clone)]
400pub struct ProofTaskCtx {
401    /// The sorted collection of cached in-memory intermediate trie nodes that can be reused for
402    /// computation.
403    nodes_sorted: Arc<TrieUpdatesSorted>,
404    /// The sorted in-memory overlay hashed state.
405    state_sorted: Arc<HashedPostStateSorted>,
406    /// The collection of prefix sets for the computation. Since the prefix sets _always_
407    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
408    /// if we have cached nodes for them.
409    prefix_sets: Arc<TriePrefixSetsMut>,
410}
411
412impl ProofTaskCtx {
413    /// Creates a new [`ProofTaskCtx`] with the given sorted nodes and state.
414    pub const fn new(
415        nodes_sorted: Arc<TrieUpdatesSorted>,
416        state_sorted: Arc<HashedPostStateSorted>,
417        prefix_sets: Arc<TriePrefixSetsMut>,
418    ) -> Self {
419        Self { nodes_sorted, state_sorted, prefix_sets }
420    }
421}
422
423/// Message used to communicate with [`ProofTaskManager`].
424#[derive(Debug)]
425pub enum ProofTaskMessage<Tx> {
426    /// A request to queue a proof task.
427    QueueTask(ProofTaskKind),
428    /// A returned database transaction.
429    Transaction(ProofTaskTx<Tx>),
430    /// A request to terminate the proof task manager.
431    Terminate,
432}
433
434/// Proof task kind.
435///
436/// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
437/// specifies the type of proof task to be executed.
438#[derive(Debug)]
439pub enum ProofTaskKind {
440    /// A storage proof request.
441    StorageProof(StorageProofInput, Sender<StorageProofResult>),
442    /// A blinded account node request.
443    BlindedAccountNode(Nibbles, Sender<BlindedNodeResult>),
444    /// A blinded storage node request.
445    BlindedStorageNode(B256, Nibbles, Sender<BlindedNodeResult>),
446}
447
448/// A handle that wraps a single proof task sender that sends a terminate message on `Drop` if the
449/// number of active handles went to zero.
450#[derive(Debug)]
451pub struct ProofTaskManagerHandle<Tx> {
452    /// The sender for the proof task manager.
453    sender: Sender<ProofTaskMessage<Tx>>,
454    /// The number of active handles.
455    active_handles: Arc<AtomicUsize>,
456}
457
458impl<Tx> ProofTaskManagerHandle<Tx> {
459    /// Creates a new [`ProofTaskManagerHandle`] with the given sender.
460    pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
461        active_handles.fetch_add(1, Ordering::SeqCst);
462        Self { sender, active_handles }
463    }
464
465    /// Queues a task to the proof task manager.
466    pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
467        self.sender.send(ProofTaskMessage::QueueTask(task))
468    }
469
470    /// Terminates the proof task manager.
471    pub fn terminate(&self) {
472        let _ = self.sender.send(ProofTaskMessage::Terminate);
473    }
474}
475
476impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
477    fn clone(&self) -> Self {
478        Self::new(self.sender.clone(), self.active_handles.clone())
479    }
480}
481
482impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
483    fn drop(&mut self) {
484        // Decrement the number of active handles and terminate the manager if it was the last
485        // handle.
486        if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
487            self.terminate();
488        }
489    }
490}
491
492impl<Tx: DbTx> BlindedProviderFactory for ProofTaskManagerHandle<Tx> {
493    type AccountNodeProvider = ProofTaskBlindedNodeProvider<Tx>;
494    type StorageNodeProvider = ProofTaskBlindedNodeProvider<Tx>;
495
496    fn account_node_provider(&self) -> Self::AccountNodeProvider {
497        ProofTaskBlindedNodeProvider::AccountNode { sender: self.sender.clone() }
498    }
499
500    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
501        ProofTaskBlindedNodeProvider::StorageNode { account, sender: self.sender.clone() }
502    }
503}
504
505/// Blinded node provider for retrieving trie nodes by path.
506#[derive(Debug)]
507pub enum ProofTaskBlindedNodeProvider<Tx> {
508    /// Blinded account trie node provider.
509    AccountNode {
510        /// Sender to the proof task.
511        sender: Sender<ProofTaskMessage<Tx>>,
512    },
513    /// Blinded storage trie node provider.
514    StorageNode {
515        /// Target account.
516        account: B256,
517        /// Sender to the proof task.
518        sender: Sender<ProofTaskMessage<Tx>>,
519    },
520}
521
522impl<Tx: DbTx> BlindedProvider for ProofTaskBlindedNodeProvider<Tx> {
523    fn blinded_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
524        let (tx, rx) = channel();
525        match self {
526            Self::AccountNode { sender } => {
527                let _ = sender.send(ProofTaskMessage::QueueTask(
528                    ProofTaskKind::BlindedAccountNode(path.clone(), tx),
529                ));
530            }
531            Self::StorageNode { sender, account } => {
532                let _ = sender.send(ProofTaskMessage::QueueTask(
533                    ProofTaskKind::BlindedStorageNode(*account, path.clone(), tx),
534                ));
535            }
536        }
537
538        rx.recv().unwrap()
539    }
540}