use crate::root::ParallelStateRootError;
use alloy_primitives::{map::B256Set, B256};
use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use reth_db_api::transaction::DbTx;
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
use reth_provider::{
    providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
    ProviderResult,
};
use reth_trie::{
    hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
    prefix_set::TriePrefixSetsMut,
    proof::{ProofTrieNodeProviderFactory, StorageProof},
    trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
    updates::TrieUpdatesSorted,
    DecodedStorageMultiProof, HashedPostStateSorted, Nibbles,
};
use reth_trie_common::{
    added_removed_keys::MultiAddedRemovedKeys,
    prefix_set::{PrefixSet, PrefixSetMut},
};
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
use std::{
    collections::VecDeque,
    sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc::{channel, Receiver, SendError, Sender},
        Arc,
    },
    time::Instant,
};
use tokio::runtime::Handle;
use tracing::trace;

#[cfg(feature = "metrics")]
use crate::proof_task_metrics::ProofTaskMetrics;

/// Result of a storage multiproof computation.
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
/// Result of a blinded trie node lookup.
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;

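/// Internal job dispatched to the storage worker pool.
///
/// Not exposed publicly; storage work reaches the pool through [`ProofTaskKind`] messages
/// sent via a [`ProofTaskManagerHandle`] and routed here by the manager.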
#[derive(Debug)]
enum StorageWorkerJob {
    /// A storage multiproof computation request.
    StorageProof {
        /// Input parameters for the storage proof.
        input: StorageProofInput,
        /// Channel used to send the result back to the caller.
        result_sender: Sender<StorageProofResult>,
    },
    /// A blinded storage trie node retrieval request.
    BlindedStorageNode {
        /// The account whose storage trie is queried.
        account: B256,
        /// The path to the node within the storage trie.
        path: Nibbles,
        /// Channel used to send the result back to the caller.
        result_sender: Sender<TrieNodeProviderResult>,
    },
}

impl StorageWorkerJob {
    /// Sends a "worker pool unavailable" error to the job's result channel.
    ///
    /// Returns `Err(())` if the receiver has already been dropped.
    fn send_worker_unavailable_error(&self) -> Result<(), ()> {
        let error =
            ParallelStateRootError::Other("Storage proof worker pool unavailable".to_string());

        match self {
            Self::StorageProof { result_sender, .. } => {
                result_sender.send(Err(error)).map_err(|_| ())
            }
            Self::BlindedStorageNode { result_sender, .. } => result_sender
                .send(Err(SparseTrieError::from(SparseTrieErrorKind::Other(Box::new(error)))))
                .map_err(|_| ()),
        }
    }
}

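/// Manager of proof computation tasks.
///
/// Storage trie work (storage multiproofs and blinded storage nodes) is dispatched to a
/// pre-spawned worker pool over an unbounded crossbeam channel, while blinded account node
/// requests are queued and executed on-demand on the executor, bounded by `max_concurrency`
/// database transactions.
///
/// Illustrative flow (a sketch only; setting up the consistent view, task context, and
/// runtime handle is elided and depends on the caller):
///
/// ```text
/// let manager = ProofTaskManager::new(handle, view, task_ctx, max_concurrency, workers)?;
/// let proof_handle = manager.handle();
/// std::thread::spawn(move || manager.run()); // drive the message loop to completion
/// proof_handle.queue_task(ProofTaskKind::StorageProof(input, result_tx))?;
/// // the result arrives on the receiver paired with `result_tx`; dropping every handle
/// // (or calling `terminate`) shuts the manager down
/// ```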
#[derive(Debug)]
pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
    /// Sender for dispatching jobs to the storage worker pool.
    storage_work_tx: CrossbeamSender<StorageWorkerJob>,

    /// Number of storage workers that were successfully spawned.
    storage_worker_count: usize,

    /// Maximum number of database transactions to create for on-demand account trie tasks.
    max_concurrency: usize,

    /// Number of database transactions created so far for on-demand tasks.
    total_transactions: usize,

    /// Queue of pending on-demand tasks (blinded account node requests).
    pending_tasks: VecDeque<ProofTaskKind>,

    /// Reusable database transactions returned by completed on-demand tasks.
    proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,

    /// Consistent view of the database used to create new transactions.
    view: ConsistentDbView<Factory>,

    /// Shared context passed to every spawned task.
    task_ctx: ProofTaskCtx,

    /// Tokio runtime handle used to spawn blocking tasks.
    executor: Handle,

    /// Receiver for incoming [`ProofTaskMessage`]s.
    proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,

    /// Sender cloned into every [`ProofTaskManagerHandle`].
    tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,

    /// Number of active handles; the manager terminates once it drops to zero.
    active_handles: Arc<AtomicUsize>,

    /// Metrics for proof task operations.
    #[cfg(feature = "metrics")]
    metrics: ProofTaskMetrics,
}

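/// Worker loop for storage trie operations.
///
/// Each worker owns a dedicated long-lived database transaction and processes
/// [`StorageWorkerJob`]s from the shared crossbeam channel until the channel is closed,
/// computing storage multiproofs and retrieving blinded storage trie nodes.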
fn storage_worker_loop<Tx>(
    proof_tx: ProofTaskTx<Tx>,
    work_rx: CrossbeamReceiver<StorageWorkerJob>,
    worker_id: usize,
) where
    Tx: DbTx,
{
    tracing::debug!(
        target: "trie::proof_task",
        worker_id,
        "Storage worker started"
    );

    let (trie_cursor_factory, hashed_cursor_factory) = proof_tx.create_factories();

    let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
        trie_cursor_factory.clone(),
        hashed_cursor_factory.clone(),
        proof_tx.task_ctx.prefix_sets.clone(),
    );

    let mut storage_proofs_processed = 0u64;
    let mut storage_nodes_processed = 0u64;

    while let Ok(job) = work_rx.recv() {
        match job {
            StorageWorkerJob::StorageProof { input, result_sender } => {
                let hashed_address = input.hashed_address;

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    hashed_address = ?hashed_address,
                    prefix_set_len = input.prefix_set.len(),
                    target_slots = input.target_slots.len(),
                    "Processing storage proof"
                );

                let proof_start = Instant::now();
                let result = proof_tx.compute_storage_proof(
                    input,
                    trie_cursor_factory.clone(),
                    hashed_cursor_factory.clone(),
                );

                let proof_elapsed = proof_start.elapsed();
                storage_proofs_processed += 1;

                if result_sender.send(result).is_err() {
                    tracing::debug!(
                        target: "trie::proof_task",
                        worker_id,
                        hashed_address = ?hashed_address,
                        storage_proofs_processed,
                        "Storage proof receiver dropped, discarding result"
                    );
                }

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    hashed_address = ?hashed_address,
                    proof_time_us = proof_elapsed.as_micros(),
                    total_processed = storage_proofs_processed,
                    "Storage proof completed"
                );
            }

            StorageWorkerJob::BlindedStorageNode { account, path, result_sender } => {
                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    ?account,
                    ?path,
                    "Processing blinded storage node"
                );

                let start = Instant::now();
                let result =
                    blinded_provider_factory.storage_node_provider(account).trie_node(&path);
                let elapsed = start.elapsed();

                storage_nodes_processed += 1;

                if result_sender.send(result).is_err() {
                    tracing::debug!(
                        target: "trie::proof_task",
                        worker_id,
                        ?account,
                        ?path,
                        storage_nodes_processed,
                        "Blinded storage node receiver dropped, discarding result"
                    );
                }

                trace!(
                    target: "trie::proof_task",
                    worker_id,
                    ?account,
                    ?path,
                    elapsed_us = elapsed.as_micros(),
                    total_processed = storage_nodes_processed,
                    "Blinded storage node completed"
                );
            }
        }
    }

    tracing::debug!(
        target: "trie::proof_task",
        worker_id,
        storage_proofs_processed,
        storage_nodes_processed,
        "Storage worker shutting down"
    );
}

impl<Factory> ProofTaskManager<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader>,
{
    /// Creates a new [`ProofTaskManager`] and eagerly spawns `storage_worker_count` storage
    /// workers, each with its own long-lived database transaction taken from the consistent
    /// view. `max_concurrency` bounds the number of additional transactions created for
    /// on-demand account trie tasks.
    ///
    /// Returns an error if the consistent view fails to produce a read-only provider.
    pub fn new(
        executor: Handle,
        view: ConsistentDbView<Factory>,
        task_ctx: ProofTaskCtx,
        max_concurrency: usize,
        storage_worker_count: usize,
    ) -> ProviderResult<Self> {
        let (tx_sender, proof_task_rx) = channel();

        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();

        tracing::info!(
            target: "trie::proof_task",
            storage_worker_count,
            max_concurrency,
            "Initializing storage worker pool with unbounded queue"
        );

        let mut spawned_workers = 0;
        for worker_id in 0..storage_worker_count {
            let provider_ro = view.provider_ro()?;

            let tx = provider_ro.into_tx();
            let proof_task_tx = ProofTaskTx::new(tx, task_ctx.clone(), worker_id);
            let work_rx = storage_work_rx.clone();

            executor.spawn_blocking(move || storage_worker_loop(proof_task_tx, work_rx, worker_id));

            spawned_workers += 1;

            tracing::debug!(
                target: "trie::proof_task",
                worker_id,
                spawned_workers,
                "Storage worker spawned successfully"
            );
        }

        Ok(Self {
            storage_work_tx,
            storage_worker_count: spawned_workers,
            max_concurrency,
            total_transactions: 0,
            pending_tasks: VecDeque::new(),
            proof_task_txs: Vec::with_capacity(max_concurrency),
            view,
            task_ctx,
            executor,
            proof_task_rx,
            tx_sender,
            active_handles: Arc::new(AtomicUsize::new(0)),

            #[cfg(feature = "metrics")]
            metrics: ProofTaskMetrics::default(),
        })
    }

    /// Returns a new handle for sending tasks to this manager.
    pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
        ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
    }
}

impl<Factory> ProofTaskManager<Factory>
where
    Factory: DatabaseProviderFactory<Provider: BlockReader> + 'static,
{
    /// Queues an on-demand proof task for later execution.
    pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
        self.pending_tasks.push_back(task);
    }

    /// Returns a reusable database transaction, creating a new one if fewer than
    /// `max_concurrency` transactions exist. Returns `Ok(None)` when all transactions are
    /// currently in use.
    pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
        if let Some(proof_task_tx) = self.proof_task_txs.pop() {
            return Ok(Some(proof_task_tx));
        }

        if self.total_transactions < self.max_concurrency {
            let provider_ro = self.view.provider_ro()?;
            let tx = provider_ro.into_tx();
            self.total_transactions += 1;
            return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone(), self.total_transactions)));
        }

        Ok(None)
    }

    /// Spawns the next queued task on the executor if a database transaction is available;
    /// otherwise the task is pushed back to the front of the queue.
    pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
        let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };

        let Some(proof_task_tx) = self.get_or_create_tx()? else {
            self.pending_tasks.push_front(task);
            return Ok(())
        };

        let tx_sender = self.tx_sender.clone();
        self.executor.spawn_blocking(move || match task {
            ProofTaskKind::BlindedAccountNode(path, sender) => {
                proof_task_tx.blinded_account_node(path, sender, tx_sender);
            }
            ProofTaskKind::BlindedStorageNode(_, _, _) | ProofTaskKind::StorageProof(_, _) => {
                unreachable!("Storage trie operations should be routed to worker pool")
            }
        });

        Ok(())
    }

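    /// Runs the main message loop until all senders are dropped or termination is requested.
    ///
    /// Storage proofs and blinded storage node requests are dispatched straight to the worker
    /// pool; blinded account node requests are queued and spawned on-demand. Returned
    /// transactions are collected for reuse. On [`ProofTaskMessage::Terminate`] the storage
    /// work channel is dropped, so workers shut down once the remaining queue is drained.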
    pub fn run(mut self) -> ProviderResult<()> {
        loop {
            match self.proof_task_rx.recv() {
                Ok(message) => {
                    match message {
                        ProofTaskMessage::QueueTask(task) => match task {
                            ProofTaskKind::StorageProof(input, sender) => {
                                match self.storage_work_tx.send(StorageWorkerJob::StorageProof {
                                    input,
                                    result_sender: sender,
                                }) {
                                    Ok(_) => {
                                        tracing::trace!(
                                            target: "trie::proof_task",
                                            "Storage proof dispatched to worker pool"
                                        );
                                    }
                                    Err(crossbeam_channel::SendError(job)) => {
                                        tracing::error!(
                                            target: "trie::proof_task",
                                            storage_worker_count = self.storage_worker_count,
                                            "Worker pool disconnected, cannot process storage proof"
                                        );

                                        let _ = job.send_worker_unavailable_error();
                                    }
                                }
                            }

                            ProofTaskKind::BlindedStorageNode(account, path, sender) => {
                                #[cfg(feature = "metrics")]
                                {
                                    self.metrics.storage_nodes += 1;
                                }

                                match self.storage_work_tx.send(
                                    StorageWorkerJob::BlindedStorageNode {
                                        account,
                                        path,
                                        result_sender: sender,
                                    },
                                ) {
                                    Ok(_) => {
                                        tracing::trace!(
                                            target: "trie::proof_task",
                                            ?account,
                                            ?path,
                                            "Blinded storage node dispatched to worker pool"
                                        );
                                    }
                                    Err(crossbeam_channel::SendError(job)) => {
                                        tracing::warn!(
                                            target: "trie::proof_task",
                                            storage_worker_count = self.storage_worker_count,
                                            ?account,
                                            ?path,
                                            "Worker pool disconnected, cannot process blinded storage node"
                                        );

                                        let _ = job.send_worker_unavailable_error();
                                    }
                                }
                            }

                            ProofTaskKind::BlindedAccountNode(_, _) => {
                                #[cfg(feature = "metrics")]
                                {
                                    self.metrics.account_nodes += 1;
                                }
                                self.queue_proof_task(task);
                            }
                        },
                        ProofTaskMessage::Transaction(tx) => {
                            self.proof_task_txs.push(tx);
                        }
                        ProofTaskMessage::Terminate => {
                            drop(self.storage_work_tx);

                            tracing::debug!(
                                target: "trie::proof_task",
                                storage_worker_count = self.storage_worker_count,
                                "Shutting down proof task manager, signaling workers to terminate"
                            );

                            #[cfg(feature = "metrics")]
                            self.metrics.record();

                            return Ok(())
                        }
                    }
                }
                Err(_) => return Ok(()),
            };

            self.try_spawn_next()?;
        }
    }
}

/// Cursor factory pair created from a database transaction overlaid with the in-memory
/// sorted trie updates and hashed post-state from [`ProofTaskCtx`].
type ProofFactories<'a, Tx> = (
    InMemoryTrieCursorFactory<DatabaseTrieCursorFactory<&'a Tx>, &'a TrieUpdatesSorted>,
    HashedPostStateCursorFactory<DatabaseHashedCursorFactory<&'a Tx>, &'a HashedPostStateSorted>,
);

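/// A worker unit holding a database transaction and the shared task context, used to execute
/// proof and trie node requests against a consistent database view.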
#[derive(Debug)]
pub struct ProofTaskTx<Tx> {
    /// The database transaction.
    tx: Tx,
    /// Shared task context (sorted trie updates, sorted post-state, prefix sets).
    task_ctx: ProofTaskCtx,
    /// Identifier used in logs: the worker id for pool workers, or the transaction count for
    /// on-demand tasks.
    id: usize,
}

impl<Tx> ProofTaskTx<Tx> {
    /// Creates a new [`ProofTaskTx`] from a database transaction, the task context, and an id.
    const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
        Self { tx, task_ctx, id }
    }
}

impl<Tx> ProofTaskTx<Tx>
where
    Tx: DbTx,
{
    /// Creates cursor factories that overlay the in-memory sorted trie updates and hashed
    /// post-state on top of the database transaction.
    #[inline]
    fn create_factories(&self) -> ProofFactories<'_, Tx> {
        let trie_cursor_factory = InMemoryTrieCursorFactory::new(
            DatabaseTrieCursorFactory::new(&self.tx),
            self.task_ctx.nodes_sorted.as_ref(),
        );

        let hashed_cursor_factory = HashedPostStateCursorFactory::new(
            DatabaseHashedCursorFactory::new(&self.tx),
            self.task_ctx.state_sorted.as_ref(),
        );

        (trie_cursor_factory, hashed_cursor_factory)
    }

    /// Computes a storage multiproof for the given input using the provided cursor factories
    /// and decodes the result.
    #[inline]
    fn compute_storage_proof(
        &self,
        input: StorageProofInput,
        trie_cursor_factory: impl TrieCursorFactory,
        hashed_cursor_factory: impl HashedCursorFactory,
    ) -> StorageProofResult {
        let StorageProofInput {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        } = input;

        let multi_added_removed_keys =
            multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
        let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);

        let span = tracing::trace_span!(
            target: "trie::proof_task",
            "Storage proof calculation",
            hashed_address = ?hashed_address,
            worker_id = self.id,
        );
        let _span_guard = span.enter();

        let proof_start = Instant::now();

        let raw_proof_result =
            StorageProof::new_hashed(trie_cursor_factory, hashed_cursor_factory, hashed_address)
                .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
                .with_branch_node_masks(with_branch_node_masks)
                .with_added_removed_keys(added_removed_keys)
                .storage_multiproof(target_slots)
                .map_err(|e| ParallelStateRootError::Other(e.to_string()));

        let decoded_result = raw_proof_result.and_then(|raw_proof| {
            raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
                ParallelStateRootError::Other(format!(
                    "Failed to decode storage proof for {}: {}",
                    hashed_address, e
                ))
            })
        });

        trace!(
            target: "trie::proof_task",
            hashed_address = ?hashed_address,
            proof_time_us = proof_start.elapsed().as_micros(),
            worker_id = self.id,
            "Completed storage proof calculation"
        );

        decoded_result
    }

    /// Retrieves the blinded account trie node at `path`, sends the result through
    /// `result_sender`, and returns the database transaction to the manager via `tx_sender`
    /// for reuse.
    fn blinded_account_node(
        self,
        path: Nibbles,
        result_sender: Sender<TrieNodeProviderResult>,
        tx_sender: Sender<ProofTaskMessage<Tx>>,
    ) {
        trace!(
            target: "trie::proof_task",
            ?path,
            "Starting blinded account node retrieval"
        );

        let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();

        let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
            trie_cursor_factory,
            hashed_cursor_factory,
            self.task_ctx.prefix_sets.clone(),
        );

        let start = Instant::now();
        let result = blinded_provider_factory.account_node_provider().trie_node(&path);
        trace!(
            target: "trie::proof_task",
            ?path,
            elapsed = ?start.elapsed(),
            "Completed blinded account node retrieval"
        );

        if let Err(error) = result_sender.send(result) {
            tracing::error!(
                target: "trie::proof_task",
                ?path,
                ?error,
                "Failed to send blinded account node result"
            );
        }

        let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
    }
}

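/// Input parameters for a storage multiproof computation.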
#[derive(Debug)]
pub struct StorageProofInput {
    /// The hashed address for which the storage proof is computed.
    hashed_address: B256,
    /// The prefix set describing the changed storage slots.
    prefix_set: PrefixSet,
    /// The target slots to include in the multiproof.
    target_slots: B256Set,
    /// Whether branch node masks should be collected.
    with_branch_node_masks: bool,
    /// Optional added/removed key sets passed to the proof calculation.
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}

impl StorageProofInput {
    /// Creates a new [`StorageProofInput`] from the given parts.
    pub const fn new(
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
        with_branch_node_masks: bool,
        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    ) -> Self {
        Self {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        }
    }
}

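/// Shared context passed to every proof computation task.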
#[derive(Debug, Clone)]
pub struct ProofTaskCtx {
    /// The sorted in-memory trie updates overlaid on the database.
    nodes_sorted: Arc<TrieUpdatesSorted>,
    /// The sorted in-memory hashed post-state overlaid on the database.
    state_sorted: Arc<HashedPostStateSorted>,
    /// The prefix sets used during the proof calculation.
    prefix_sets: Arc<TriePrefixSetsMut>,
}

impl ProofTaskCtx {
    /// Creates a new [`ProofTaskCtx`] from the sorted trie updates, sorted hashed post-state,
    /// and prefix sets.
    pub const fn new(
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
    ) -> Self {
        Self { nodes_sorted, state_sorted, prefix_sets }
    }
}

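/// Message sent to the [`ProofTaskManager`] message loop.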
#[derive(Debug)]
pub enum ProofTaskMessage<Tx> {
    /// A request to queue a proof task.
    QueueTask(ProofTaskKind),
    /// A database transaction returned by a finished task, ready for reuse.
    Transaction(ProofTaskTx<Tx>),
    /// A request to terminate the proof task manager.
    Terminate,
}

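/// Kind of proof task that can be queued through a [`ProofTaskManagerHandle`].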
#[derive(Debug)]
pub enum ProofTaskKind {
    /// A storage multiproof request for an account, with a channel for the result.
    StorageProof(StorageProofInput, Sender<StorageProofResult>),
    /// A blinded account trie node request at the given path, with a channel for the result.
    BlindedAccountNode(Nibbles, Sender<TrieNodeProviderResult>),
    /// A blinded storage trie node request for an account and path, with a channel for the
    /// result.
    BlindedStorageNode(B256, Nibbles, Sender<TrieNodeProviderResult>),
}

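/// Cloneable handle for sending messages to the [`ProofTaskManager`].
///
/// The manager tracks the number of live handles and terminates once the last one is dropped.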
#[derive(Debug)]
pub struct ProofTaskManagerHandle<Tx> {
    /// The sender to the proof task manager's message loop.
    sender: Sender<ProofTaskMessage<Tx>>,
    /// The number of currently active handles.
    active_handles: Arc<AtomicUsize>,
}

impl<Tx> ProofTaskManagerHandle<Tx> {
    /// Creates a new [`ProofTaskManagerHandle`], incrementing the active handle count.
    pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
        active_handles.fetch_add(1, Ordering::SeqCst);
        Self { sender, active_handles }
    }

    /// Queues a task with the proof task manager.
    pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
        self.sender.send(ProofTaskMessage::QueueTask(task))
    }

    /// Signals the proof task manager to terminate.
    pub fn terminate(&self) {
        let _ = self.sender.send(ProofTaskMessage::Terminate);
    }
}

impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
    fn clone(&self) -> Self {
        Self::new(self.sender.clone(), self.active_handles.clone())
    }
}

impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
    fn drop(&mut self) {
        // Terminate the manager once the last active handle is dropped.
        if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
            self.terminate();
        }
    }
}

impl<Tx: DbTx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx> {
    type AccountNodeProvider = ProofTaskTrieNodeProvider<Tx>;
    type StorageNodeProvider = ProofTaskTrieNodeProvider<Tx>;

    fn account_node_provider(&self) -> Self::AccountNodeProvider {
        ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
    }

    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
        ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
    }
}

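/// Trie node provider that forwards blinded node requests to the proof task manager and
/// blocks on the response.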
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider<Tx> {
    /// Provider of blinded account trie nodes.
    AccountNode {
        /// Sender to the proof task manager.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
    /// Provider of blinded storage trie nodes.
    StorageNode {
        /// The account whose storage trie is queried.
        account: B256,
        /// Sender to the proof task manager.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
}

impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
        let (tx, rx) = channel();
        match self {
            Self::AccountNode { sender } => {
                let _ = sender.send(ProofTaskMessage::QueueTask(
                    ProofTaskKind::BlindedAccountNode(*path, tx),
                ));
            }
            Self::StorageNode { sender, account } => {
                let _ = sender.send(ProofTaskMessage::QueueTask(
                    ProofTaskKind::BlindedStorageNode(*account, *path, tx),
                ));
            }
        }

        rx.recv().unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::map::B256Map;
    use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
    use reth_trie_common::{
        prefix_set::TriePrefixSetsMut, updates::TrieUpdatesSorted, HashedAccountsSorted,
        HashedPostStateSorted,
    };
    use std::sync::Arc;
    use tokio::{runtime::Builder, task};

    fn test_ctx() -> ProofTaskCtx {
        ProofTaskCtx::new(
            Arc::new(TrieUpdatesSorted::default()),
            Arc::new(HashedPostStateSorted::new(
                HashedAccountsSorted::default(),
                B256Map::default(),
            )),
            Arc::new(TriePrefixSetsMut::default()),
        )
    }

    /// Ensures that the storage worker pool size and the on-demand transaction limit are
    /// tracked independently of each other.
    #[test]
    fn proof_task_manager_independent_pools() {
        let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
        runtime.block_on(async {
            let handle = tokio::runtime::Handle::current();
            let factory = create_test_provider_factory();
            let view = ConsistentDbView::new(factory, None);
            let ctx = test_ctx();

            // 5 storage workers, but only 1 on-demand transaction allowed.
            let manager = ProofTaskManager::new(handle.clone(), view, ctx, 1, 5).unwrap();
            assert_eq!(manager.storage_worker_count, 5);
            assert_eq!(manager.max_concurrency, 1);

            drop(manager);
            task::yield_now().await;
        });
    }
}