reth_trie_parallel/
proof_task.rs

//! A task that manages sending proof requests to a number of tasks that have longer-running
//! database transactions.
//!
//! The [`ProofTaskManager`] ensures that there are a max number of currently executing proof tasks,
//! and is responsible for managing the fixed number of database transactions created at the start
//! of the task.
//!
//! Individual [`ProofTaskTx`] instances manage a dedicated [`InMemoryTrieCursorFactory`] and
//! [`HashedPostStateCursorFactory`], which are each backed by a database transaction.
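//!
//! # Example
//!
//! A minimal lifecycle sketch, assuming a Tokio runtime `handle`, a `view`
//! ([`ConsistentDbView`]) and a [`ProofTaskCtx`] are already available:
//!
//! ```ignore
//! let manager = ProofTaskManager::new(handle.clone(), view, task_ctx, 4, 8)?;
//! let proof_handle = manager.handle();
//! // Drive the manager loop on a blocking thread.
//! handle.spawn_blocking(move || manager.run());
//! // ... queue tasks through `proof_handle` ...
//! proof_handle.terminate();
//! ```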

use crate::root::ParallelStateRootError;
use alloy_primitives::{map::B256Set, B256};
use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use reth_db_api::transaction::DbTx;
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderResult,
};
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    prefix_set::TriePrefixSetsMut,
    proof::{ProofTrieNodeProviderFactory, StorageProof},
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    DecodedStorageMultiProof, HashedPostStateSorted, Nibbles,
};
use reth_trie_common::{
    added_removed_keys::MultiAddedRemovedKeys,
    prefix_set::{PrefixSet, PrefixSetMut},
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
use std::{
    collections::VecDeque,
    sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc::{channel, Receiver, SendError, Sender},
        Arc,
    },
    time::Instant,
};
use tokio::runtime::Handle;
use tracing::trace;

#[cfg(feature = "metrics")]
use crate::proof_task_metrics::ProofTaskMetrics;

type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;

/// Internal message for storage workers.
///
/// This is NOT exposed publicly. External callers use `ProofTaskKind::StorageProof` or
/// `ProofTaskKind::BlindedStorageNode`, which are routed through the manager's `std::mpsc` channel.
#[derive(Debug)]
enum StorageWorkerJob {
    /// Storage proof computation request
    StorageProof {
        /// Storage proof input parameters
        input: StorageProofInput,
        /// Channel to send result back to original caller
        result_sender: Sender<StorageProofResult>,
    },
    /// Blinded storage node retrieval request
    BlindedStorageNode {
        /// Target account
        account: B256,
        /// Path to the storage node
        path: Nibbles,
        /// Channel to send result back to original caller
        result_sender: Sender<TrieNodeProviderResult>,
    },
}

impl StorageWorkerJob {
    /// Sends an error back to the caller when the worker pool is unavailable.
    ///
    /// Returns `Ok(())` if the error was sent successfully, or `Err(())` if the receiver was
    /// dropped.
    fn send_worker_unavailable_error(&self) -> Result<(), ()> {
        let error =
            ParallelStateRootError::Other("Storage proof worker pool unavailable".to_string());

        match self {
            Self::StorageProof { result_sender, .. } => {
                result_sender.send(Err(error)).map_err(|_| ())
            }
            Self::BlindedStorageNode { result_sender, .. } => result_sender
                .send(Err(SparseTrieError::from(SparseTrieErrorKind::Other(Box::new(error)))))
                .map_err(|_| ()),
        }
    }
}

/// Manager for coordinating proof request execution across different task types.
///
/// # Architecture
///
/// This manager handles two distinct execution paths:
///
/// 1. **Storage Worker Pool** (for storage trie operations):
///    - Pre-spawned workers with dedicated long-lived transactions
///    - Handles `StorageProof` and `BlindedStorageNode` requests
///    - Tasks queued via crossbeam unbounded channel
///    - Workers process jobs continuously without per-request transaction overhead
///    - Unbounded queue ensures all storage proofs benefit from transaction reuse
///
/// 2. **On-Demand Execution** (for account trie operations):
///    - Lazy transaction creation for `BlindedAccountNode` requests
///    - Transactions returned to pool after use for reuse
///
/// # Public Interface
///
/// The public interface through `ProofTaskManagerHandle` allows external callers to:
/// - Submit tasks via `queue_task(ProofTaskKind)`
/// - Use standard `std::mpsc` message passing
/// - Receive consistent return types and error handling
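///
/// # Example
///
/// A sketch of queueing a storage proof through a handle; `input` is a hypothetical
/// [`StorageProofInput`]:
///
/// ```ignore
/// let proof_handle = manager.handle();
/// let (result_tx, result_rx) = std::sync::mpsc::channel();
/// proof_handle.queue_task(ProofTaskKind::StorageProof(input, result_tx))?;
/// let proof = result_rx.recv().expect("manager is alive")?;
/// ```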
#[derive(Debug)]
pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
    /// Sender for storage worker jobs to the worker pool.
    storage_work_tx: CrossbeamSender<StorageWorkerJob>,

    /// Number of storage workers successfully spawned.
    ///
    /// May be less than requested if concurrency limits reduce the worker budget.
    storage_worker_count: usize,

    /// Max number of database transactions to create for on-demand account trie operations.
    max_concurrency: usize,

    /// Number of database transactions created for on-demand operations.
    total_transactions: usize,

    /// Proof tasks pending execution (account trie operations only).
    pending_tasks: VecDeque<ProofTaskKind>,

    /// The proof task transactions, containing owned cursor factories that are reused for proof
    /// calculation (account trie operations only).
    proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,

    /// Consistent view provider used for creating transactions on-demand.
    view: ConsistentDbView<Factory>,

    /// Proof task context shared across all proof tasks.
    task_ctx: ProofTaskCtx,

    /// The underlying handle from which to spawn proof tasks.
    executor: Handle,

    /// Receives proof task requests from [`ProofTaskManagerHandle`].
    proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,

    /// Internal channel for on-demand tasks to return transactions after use.
    tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,

    /// The number of active handles.
    ///
    /// Incremented in [`ProofTaskManagerHandle::new`] and decremented in
    /// [`ProofTaskManagerHandle::drop`].
    active_handles: Arc<AtomicUsize>,

    /// Metrics tracking proof task operations.
    #[cfg(feature = "metrics")]
    metrics: ProofTaskMetrics,
}

/// Worker loop for storage trie operations.
///
/// # Lifecycle
///
/// Each worker:
/// 1. Receives a `StorageWorkerJob` from the crossbeam unbounded channel
/// 2. Computes the result using its dedicated long-lived transaction
/// 3. Sends the result directly to the original caller via `std::mpsc`
/// 4. Repeats until the channel closes (graceful shutdown)
///
/// # Transaction Reuse
///
/// Reuses the same transaction and cursor factories across multiple operations
/// to avoid transaction creation and cursor factory setup overhead.
///
/// # Panic Safety
///
/// If this function panics, the worker thread terminates but other workers
/// continue operating and the system degrades gracefully.
///
/// # Shutdown
///
/// The worker shuts down when the crossbeam channel closes (all senders dropped).
fn storage_worker_loop<Tx>(
    proof_tx: ProofTaskTx<Tx>,
    work_rx: CrossbeamReceiver<StorageWorkerJob>,
    worker_id: usize,
) where
    Tx: DbTx,
{
    tracing::debug!(
        target: "trie::proof_task",
        worker_id,
        "Storage worker started"
    );

    // Create factories once at worker startup to avoid recreation overhead.
    let (trie_cursor_factory, hashed_cursor_factory) = proof_tx.create_factories();

    // Create the blinded provider factory once for all blinded node requests.
    let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
        trie_cursor_factory.clone(),
        hashed_cursor_factory.clone(),
        proof_tx.task_ctx.prefix_sets.clone(),
    );

    let mut storage_proofs_processed = 0u64;
    let mut storage_nodes_processed = 0u64;

    while let Ok(job) = work_rx.recv() {
        match job {
            StorageWorkerJob::StorageProof { input, result_sender } => {
                let hashed_address = input.hashed_address;

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    hashed_address = ?hashed_address,
                    prefix_set_len = input.prefix_set.len(),
                    target_slots = input.target_slots.len(),
                    "Processing storage proof"
                );

                let proof_start = Instant::now();
                let result = proof_tx.compute_storage_proof(
                    input,
                    trie_cursor_factory.clone(),
                    hashed_cursor_factory.clone(),
                );

                let proof_elapsed = proof_start.elapsed();
                storage_proofs_processed += 1;

                if result_sender.send(result).is_err() {
                    tracing::debug!(
                        target: "trie::proof_task",
                        worker_id,
                        hashed_address = ?hashed_address,
                        storage_proofs_processed,
                        "Storage proof receiver dropped, discarding result"
                    );
                }

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    hashed_address = ?hashed_address,
                    proof_time_us = proof_elapsed.as_micros(),
                    total_processed = storage_proofs_processed,
                    "Storage proof completed"
                );
            }

            StorageWorkerJob::BlindedStorageNode { account, path, result_sender } => {
                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    ?account,
                    ?path,
                    "Processing blinded storage node"
                );

                let start = Instant::now();
                let result =
                    blinded_provider_factory.storage_node_provider(account).trie_node(&path);
                let elapsed = start.elapsed();

                storage_nodes_processed += 1;

                if result_sender.send(result).is_err() {
                    tracing::debug!(
                        target: "trie::proof_task",
                        worker_id,
                        ?account,
                        ?path,
                        storage_nodes_processed,
                        "Blinded storage node receiver dropped, discarding result"
                    );
                }

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    ?account,
                    ?path,
                    elapsed_us = elapsed.as_micros(),
                    total_processed = storage_nodes_processed,
                    "Blinded storage node completed"
                );
            }
        }
    }

    tracing::debug!(
        target: "trie::proof_task",
        worker_id,
        storage_proofs_processed,
        storage_nodes_processed,
        "Storage worker shutting down"
    );
}

impl<Factory> ProofTaskManager<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
    /// Creates a new [`ProofTaskManager`] with pre-spawned storage proof workers.
    ///
    /// The `storage_worker_count` determines how many storage workers to spawn, and
    /// `max_concurrency` determines the limit for on-demand operations (blinded account nodes).
    /// These are independent: storage workers are spawned as requested, and on-demand
    /// operations use a separate concurrency pool for blinded account nodes.
    ///
    /// Returns an error if the underlying provider fails to create the transactions required for
    /// spawning workers.
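    ///
    /// # Example
    ///
    /// A sketch, assuming `handle`, `view`, and `task_ctx` are already constructed:
    ///
    /// ```ignore
    /// // 4 on-demand transactions for account trie nodes, 8 storage workers.
    /// let manager = ProofTaskManager::new(handle, view, task_ctx, 4, 8)?;
    /// ```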
    pub fn new(
        executor: Handle,
        view: ConsistentDbView<Factory>,
        task_ctx: ProofTaskCtx,
        max_concurrency: usize,
        storage_worker_count: usize,
    ) -> ProviderResult<Self> {
        let (tx_sender, proof_task_rx) = channel();

        // Use an unbounded channel to ensure all storage operations are queued to workers.
        // This maintains transaction reuse benefits and avoids fallback to on-demand execution.
        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();

        tracing::info!(
            target: "trie::proof_task",
            storage_worker_count,
            max_concurrency,
            "Initializing storage worker pool with unbounded queue"
        );

        let mut spawned_workers = 0;
        for worker_id in 0..storage_worker_count {
            let provider_ro = view.provider_ro()?;

            let tx = provider_ro.into_tx();
            let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
            let work_rx = storage_work_rx.clone();

            executor.spawn_blocking(move || storage_worker_loop(proof_task_tx, work_rx, worker_id));

            spawned_workers += 1;

            tracing::debug!(
                target: "trie::proof_task",
                worker_id,
                spawned_workers,
                "Storage worker spawned successfully"
            );
        }

        Ok(Self {
            storage_work_tx,
            storage_worker_count: spawned_workers,
            max_concurrency,
            total_transactions: 0,
            pending_tasks: VecDeque::new(),
            proof_task_txs: Vec::with_capacity(max_concurrency),
            view,
            task_ctx,
            executor,
            proof_task_rx,
            tx_sender,
            active_handles: Arc::new(AtomicUsize::new(0)),

            #[cfg(feature = "metrics")]
            metrics: ProofTaskMetrics::default(),
        })
    }

    /// Returns a handle for sending new proof tasks to the [`ProofTaskManager`].
    pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
        ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
    }
}

impl<Factory> ProofTaskManager<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader> + 'static,
{
    /// Inserts the task into the pending tasks queue.
    pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
        self.pending_tasks.push_back(task);
    }

    /// Gets either the next available transaction, or creates a new one if all are in use and the
    /// total number of transactions created is less than the max concurrency.
    pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
        if let Some(proof_task_tx) = self.proof_task_txs.pop() {
            return Ok(Some(proof_task_tx));
        }

        // If we can create a new tx within our concurrency limits, create one on-demand.
        if self.total_transactions < self.max_concurrency {
            let provider_ro = self.view.provider_ro()?;
            let tx = provider_ro.into_tx();
            self.total_transactions += 1;
            return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone(), self.total_transactions)));
        }

        Ok(None)
    }

    /// Spawns the next queued proof task on the executor, if a transaction is available.
    ///
    /// This will return an error if a transaction must be created on-demand and the consistent
    /// view provider fails.
    pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
        let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };

        let Some(proof_task_tx) = self.get_or_create_tx()? else {
            // If there are no txs available, requeue the proof task.
            self.pending_tasks.push_front(task);
            return Ok(())
        };

        let tx_sender = self.tx_sender.clone();
        self.executor.spawn_blocking(move || match task {
            ProofTaskKind::BlindedAccountNode(path, sender) => {
                proof_task_tx.blinded_account_node(path, sender, tx_sender);
            }
            // Storage trie operations should never reach here, as they're routed to the worker
            // pool.
            ProofTaskKind::BlindedStorageNode(_, _, _) | ProofTaskKind::StorageProof(_, _) => {
                unreachable!("Storage trie operations should be routed to worker pool")
            }
        });

        Ok(())
    }

    /// Loops, managing the proof tasks, and sending new tasks to the executor.
    ///
    /// # Task Routing
    ///
    /// - **Storage Trie Operations** (`StorageProof` and `BlindedStorageNode`): Routed to the
    ///   pre-spawned worker pool via the unbounded channel.
    /// - **Account Trie Operations** (`BlindedAccountNode`): Queued for on-demand execution via
    ///   `pending_tasks`.
    ///
    /// # Shutdown
    ///
    /// On termination, `storage_work_tx` is dropped, closing the channel and
    /// signaling all workers to shut down gracefully.
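    ///
    /// The loop is typically driven on a blocking thread, for example:
    ///
    /// ```ignore
    /// executor.spawn_blocking(move || manager.run());
    /// ```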
    pub fn run(mut self) -> ProviderResult<()> {
        loop {
            match self.proof_task_rx.recv() {
                Ok(message) => {
                    match message {
                        ProofTaskMessage::QueueTask(task) => match task {
                            ProofTaskKind::StorageProof(input, sender) => {
                                match self.storage_work_tx.send(StorageWorkerJob::StorageProof {
                                    input,
                                    result_sender: sender,
                                }) {
                                    Ok(_) => {
                                        tracing::trace!(
                                            target: "trie::proof_task",
                                            "Storage proof dispatched to worker pool"
                                        );
                                    }
                                    Err(crossbeam_channel::SendError(job)) => {
                                        tracing::error!(
                                            target: "trie::proof_task",
                                            storage_worker_count = self.storage_worker_count,
                                            "Worker pool disconnected, cannot process storage proof"
                                        );

                                        // Send the error back to the caller.
                                        let _ = job.send_worker_unavailable_error();
                                    }
                                }
                            }

                            ProofTaskKind::BlindedStorageNode(account, path, sender) => {
                                #[cfg(feature = "metrics")]
                                {
                                    self.metrics.storage_nodes += 1;
                                }

                                match self.storage_work_tx.send(
                                    StorageWorkerJob::BlindedStorageNode {
                                        account,
                                        path,
                                        result_sender: sender,
                                    },
                                ) {
                                    Ok(_) => {
                                        tracing::trace!(
                                            target: "trie::proof_task",
                                            ?account,
                                            ?path,
                                            "Blinded storage node dispatched to worker pool"
                                        );
                                    }
                                    Err(crossbeam_channel::SendError(job)) => {
                                        tracing::warn!(
                                            target: "trie::proof_task",
                                            storage_worker_count = self.storage_worker_count,
                                            ?account,
                                            ?path,
                                            "Worker pool disconnected, cannot process blinded storage node"
                                        );

                                        // Send the error back to the caller.
                                        let _ = job.send_worker_unavailable_error();
                                    }
                                }
                            }

                            ProofTaskKind::BlindedAccountNode(_, _) => {
                                // Route account trie operations to pending_tasks.
                                #[cfg(feature = "metrics")]
                                {
                                    self.metrics.account_nodes += 1;
                                }
                                self.queue_proof_task(task);
                            }
                        },
                        ProofTaskMessage::Transaction(tx) => {
                            // Return the transaction to the pool for reuse.
                            self.proof_task_txs.push(tx);
                        }
                        ProofTaskMessage::Terminate => {
                            // Drop storage_work_tx to signal workers to shut down.
                            drop(self.storage_work_tx);

                            tracing::debug!(
                                target: "trie::proof_task",
                                storage_worker_count = self.storage_worker_count,
                                "Shutting down proof task manager, signaling workers to terminate"
                            );

                            // Record metrics before terminating.
                            #[cfg(feature = "metrics")]
                            self.metrics.record();

                            return Ok(())
                        }
                    }
                }
                // All senders are disconnected, so we can terminate.
                // However, this should never happen, as this struct stores a sender.
                Err(_) => return Ok(()),
            };

            // Try spawning pending account trie tasks.
            self.try_spawn_next()?;
        }
    }
}

/// Type alias for the factory tuple returned by `create_factories`.
type ProofFactories<'a, Tx> = (
    InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
    HashedPostStateCursorFactory<DatabaseHashedCursorFactory<&'a Tx>, &'a HashedPostStateSorted>,
);

/// A database transaction, together with the shared task context, that is reused across proof
/// calculations.
#[derive(Debug)]
pub struct ProofTaskTx<Tx> {
    /// The tx that is reused for proof calculations.
    tx: Tx,

    /// Trie updates, prefix sets, and state updates.
    task_ctx: ProofTaskCtx,

    /// Identifier for the tx within the context of a single [`ProofTaskManager`], used only for
    /// tracing.
    id: usize,
}

impl<Tx> ProofTaskTx<Tx> {
    /// Initializes a [`ProofTaskTx`] using the given transaction and a [`ProofTaskCtx`]. The id is
    /// used only for tracing.
    const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
        Self { tx, task_ctx, id }
    }
}

impl<Tx> ProofTaskTx<Tx>
where
    Tx: DbTx,
{
    #[inline]
    fn create_factories(&self) -> ProofFactories<'_, Tx> {
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(&self.tx),
            self.task_ctx.nodes_sorted.as_ref(),
        );

        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(&self.tx),
            self.task_ctx.state_sorted.as_ref(),
        );

        (trie_cursor_factory, hashed_cursor_factory)
    }

    /// Computes a storage proof with pre-created factories.
    ///
    /// Accepts cursor factories as parameters to allow reuse across multiple proofs.
    /// Used by storage workers in the worker pool to avoid factory recreation
    /// overhead on each proof computation.
    #[inline]
    fn compute_storage_proof(
        &self,
        input: StorageProofInput,
        trie_cursor_factory: impl TrieCursorFactory,
        hashed_cursor_factory: impl HashedCursorFactory,
    ) -> StorageProofResult {
        // Consume the input so we can move large collections (e.g. target slots) without cloning.
        let StorageProofInput {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        } = input;

        // Get or create the added/removed keys context.
        let multi_added_removed_keys =
            multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
        let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);

        let span = tracing::trace_span!(
            target: "trie::proof_task",
            "Storage proof calculation",
            hashed_address = ?hashed_address,
            worker_id = self.id,
        );
        let _span_guard = span.enter();

        let proof_start = Instant::now();

        // Compute the raw storage multiproof.
        let raw_proof_result =
            StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
                .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
                .with_branch_node_masks(with_branch_node_masks)
                .with_added_removed_keys(added_removed_keys)
                .storage_multiproof(target_slots)
                .map_err(|e| ParallelStateRootError::Other(e.to_string()));

        // Decode the proof into a `DecodedStorageMultiProof`.
        let decoded_result = raw_proof_result.and_then(|raw_proof| {
            raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
                ParallelStateRootError::Other(format!(
                    "Failed to decode storage proof for {}: {}",
                    hashed_address, e
                ))
            })
        });

        trace!(
            target: "trie::proof_task",
            hashed_address = ?hashed_address,
            proof_time_us = proof_start.elapsed().as_micros(),
            worker_id = self.id,
            "Completed storage proof calculation"
        );

        decoded_result
    }

    /// Retrieves a blinded account node by path.
    fn blinded_account_node(
        self,
        path: Nibbles,
        result_sender: Sender<TrieNodeProviderResult>,
        tx_sender: Sender<ProofTaskMessage<Tx>>,
    ) {
        trace!(
            target: "trie::proof_task",
            ?path,
            "Starting blinded account node retrieval"
        );

        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();

        let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
            trie_cursor_factory,
            hashed_cursor_factory,
            self.task_ctx.prefix_sets.clone(),
        );

        let start = Instant::now();
        let result = blinded_provider_factory.account_node_provider().trie_node(&path);
        trace!(
            target: "trie::proof_task",
            ?path,
            elapsed = ?start.elapsed(),
            "Completed blinded account node retrieval"
        );

        if let Err(error) = result_sender.send(result) {
            tracing::error!(
                target: "trie::proof_task",
                ?path,
                ?error,
                "Failed to send blinded account node result"
            );
        }

        // Send the tx back for reuse.
        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
    }
}

/// This represents an input for a storage proof.
#[derive(Debug)]
pub struct StorageProofInput {
    /// The hashed address for which the proof is calculated.
    hashed_address: B256,
    /// The prefix set for the proof calculation.
    prefix_set: PrefixSet,
    /// The target slots for the proof calculation.
    target_slots: B256Set,
    /// Whether or not to collect branch node masks.
    with_branch_node_masks: bool,
    /// Provided by the user to give the necessary context to retain extra proofs.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}

impl StorageProofInput {
    /// Creates a new [`StorageProofInput`] with the given hashed address, prefix set, and target
    /// slots.
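    ///
    /// # Example
    ///
    /// A minimal sketch with placeholder values:
    ///
    /// ```ignore
    /// let input = StorageProofInput::new(
    ///     hashed_address,
    ///     PrefixSet::default(),
    ///     B256Set::default(),
    ///     false,
    ///     None,
    /// );
    /// ```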
    pub const fn new(
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
        with_branch_node_masks: bool,
        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    ) -> Self {
        Self {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        }
    }
}

/// Data used for initializing cursor factories that is shared across all storage proof instances.
#[derive(Debug, Clone)]
pub struct ProofTaskCtx {
    /// The sorted collection of cached in-memory intermediate trie nodes that can be reused for
    /// computation.
    nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory overlay hashed state.
    state_sorted: Arc<HashedPostStateSorted>,
    /// The collection of prefix sets for the computation. Since the prefix sets _always_
    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
    /// if we have cached nodes for them.
    prefix_sets: Arc<TriePrefixSetsMut>,
}

impl ProofTaskCtx {
    /// Creates a new [`ProofTaskCtx`] with the given sorted nodes and state.
    pub const fn new(
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
    ) -> Self {
        Self { nodes_sorted, state_sorted, prefix_sets }
    }
}

/// Message used to communicate with [`ProofTaskManager`].
#[derive(Debug)]
pub enum ProofTaskMessage<Tx> {
    /// A request to queue a proof task.
    QueueTask(ProofTaskKind),
    /// A returned database transaction.
    Transaction(ProofTaskTx<Tx>),
    /// A request to terminate the proof task manager.
    Terminate,
}

/// Proof task kind.
///
/// When queueing a task using [`ProofTaskMessage::QueueTask`], this enum
/// specifies the type of proof task to be executed.
#[derive(Debug)]
pub enum ProofTaskKind {
    /// A storage proof request.
    StorageProof(StorageProofInput, Sender<StorageProofResult>),
    /// A blinded account node request.
    BlindedAccountNode(Nibbles, Sender<TrieNodeProviderResult>),
    /// A blinded storage node request.
    BlindedStorageNode(B256, Nibbles, Sender<TrieNodeProviderResult>),
}

/// A handle that wraps a single proof task sender and sends a terminate message on `Drop` when the
/// number of active handles reaches zero.
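///
/// # Example
///
/// A sketch of using the handle as a [`TrieNodeProviderFactory`]; `account` and
/// `path` are hypothetical values:
///
/// ```ignore
/// let provider = handle.storage_node_provider(account);
/// let node = provider.trie_node(&path)?;
/// ```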
#[derive(Debug)]
pub struct ProofTaskManagerHandle<Tx> {
    /// The sender for the proof task manager.
    sender: Sender<ProofTaskMessage<Tx>>,
    /// The number of active handles.
    active_handles: Arc<AtomicUsize>,
}

impl<Tx> ProofTaskManagerHandle<Tx> {
    /// Creates a new [`ProofTaskManagerHandle`] with the given sender.
    pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
        active_handles.fetch_add(1, Ordering::SeqCst);
        Self { sender, active_handles }
    }

    /// Queues a task to the proof task manager.
    pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
        self.sender.send(ProofTaskMessage::QueueTask(task))
    }

    /// Terminates the proof task manager.
    pub fn terminate(&self) {
        let _ = self.sender.send(ProofTaskMessage::Terminate);
    }
}

impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
    fn clone(&self) -> Self {
        Self::new(self.sender.clone(), self.active_handles.clone())
    }
}

impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
    fn drop(&mut self) {
        // Decrement the number of active handles and terminate the manager if it was the last
        // handle.
        if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
            self.terminate();
        }
    }
}

impl<Tx: DbTx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx> {
    type AccountNodeProvider = ProofTaskTrieNodeProvider<Tx>;
    type StorageNodeProvider = ProofTaskTrieNodeProvider<Tx>;

    fn account_node_provider(&self) -> Self::AccountNodeProvider {
        ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
    }

    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
        ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
    }
}

/// Trie node provider for retrieving trie nodes by path.
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider<Tx> {
    /// Blinded account trie node provider.
    AccountNode {
        /// Sender to the proof task.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
    /// Blinded storage trie node provider.
    StorageNode {
        /// Target account.
        account: B256,
        /// Sender to the proof task.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
}

impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
        let (tx, rx) = channel();
        match self {
            Self::AccountNode { sender } => {
                let _ = sender.send(ProofTaskMessage::QueueTask(
                    ProofTaskKind::BlindedAccountNode(*path, tx),
                ));
            }
            Self::StorageNode { sender, account } => {
                let _ = sender.send(ProofTaskMessage::QueueTask(
                    ProofTaskKind::BlindedStorageNode(*account, *path, tx),
                ));
            }
        }

        rx.recv().unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::map::B256Map;
    use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
    use reth_trie_common::{
        prefix_set::TriePrefixSetsMut, updates::TrieUpdatesSorted, HashedAccountsSorted,
        HashedPostStateSorted,
    };
    use std::sync::Arc;
    use tokio::{runtime::Builder, task};

    fn test_ctx() -> ProofTaskCtx {
        ProofTaskCtx::new(
            Arc::new(TrieUpdatesSorted::default()),
            Arc::new(HashedPostStateSorted::new(
                HashedAccountsSorted::default(),
                B256Map::default(),
            )),
            Arc::new(TriePrefixSetsMut::default()),
        )
    }

    /// Ensures `max_concurrency` is independent of storage workers.
    #[test]
    fn proof_task_manager_independent_pools() {
        let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
        runtime.block_on(async {
            let handle = tokio::runtime::Handle::current();
            let factory = create_test_provider_factory();
            let view = ConsistentDbView::new(factory, None);
            let ctx = test_ctx();

            let manager = ProofTaskManager::new(handle.clone(), view, ctx, 1, 5).unwrap();
            // With storage_worker_count=5, we get exactly 5 workers
            assert_eq!(manager.storage_worker_count, 5);
            // max_concurrency=1 is for on-demand operations only
            assert_eq!(manager.max_concurrency, 1);

            drop(manager);
            task::yield_now().await;
        });
    }
}