use crate::{
    root::ParallelStateRootError,
    stats::{ParallelTrieStats, ParallelTrieTracker},
    StorageRootTargets,
};
use alloy_primitives::{
    map::{B256Map, B256Set},
    B256,
};
use alloy_rlp::{BufMut, Encodable};
use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use dashmap::DashMap;
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
use reth_provider::{DatabaseProviderROFactory, ProviderError};
use reth_storage_errors::db::DatabaseError;
use reth_trie::{
    hashed_cursor::HashedCursorFactory,
    node_iter::{TrieElement, TrieNodeIter},
    prefix_set::{TriePrefixSets, TriePrefixSetsMut},
    proof::{ProofBlindedAccountProvider, ProofBlindedStorageProvider, StorageProof},
    trie_cursor::TrieCursorFactory,
    walker::TrieWalker,
    DecodedMultiProof, DecodedStorageMultiProof, HashBuilder, HashedPostState, MultiProofTargets,
    Nibbles, TRIE_ACCOUNT_RLP_MAX_SIZE,
};
use reth_trie_common::{
    added_removed_keys::MultiAddedRemovedKeys,
    prefix_set::{PrefixSet, PrefixSetMut},
    proof::{DecodedProofNodes, ProofRetainer},
};
use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
use std::{
    sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc::{channel, Receiver, Sender},
        Arc,
    },
    time::{Duration, Instant},
};
use tokio::runtime::Handle;
use tracing::{debug, debug_span, error, trace};

#[cfg(feature = "metrics")]
use crate::proof_task_metrics::ProofTaskTrieMetrics;

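/// Result of a storage multiproof computation.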
type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
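/// Result of retrieving a blinded trie node.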
type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;

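/// Handle for dispatching proof jobs to pre-spawned storage and account worker pools.
///
/// The handle is cheaply cloneable: all clones share the same job channels and availability
/// counters, and the worker pools shut down once every handle has been dropped.
///
/// Minimal usage sketch (the `factory` value and the worker counts are illustrative; any
/// factory whose read-only provider implements the trie and hashed cursor factories works,
/// such as the overlay factory used in the tests below):
///
/// ```ignore
/// let ctx = ProofTaskCtx::new(factory, Arc::new(TriePrefixSetsMut::default()));
/// let handle = ProofWorkerHandle::new(tokio::runtime::Handle::current(), ctx, 8, 4);
///
/// // Request a blinded account node and block on the worker's reply.
/// let rx = handle.dispatch_blinded_account_node(Nibbles::default())?;
/// let node = rx.recv().expect("worker pool is alive")?;
/// ```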
#[derive(Debug, Clone)]
pub struct ProofWorkerHandle {
    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
    account_work_tx: CrossbeamSender<AccountWorkerJob>,
    storage_available_workers: Arc<AtomicUsize>,
    account_available_workers: Arc<AtomicUsize>,
    storage_worker_count: usize,
    account_worker_count: usize,
}

impl ProofWorkerHandle {
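    /// Spawns the storage and account proof worker pools on the provided tokio runtime
    /// handle and returns a handle for dispatching jobs to them.
    ///
    /// Each worker runs on a blocking thread, opens its own read-only database provider
    /// from `task_ctx`, and processes jobs until all job senders are dropped.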
    pub fn new<Factory>(
        executor: Handle,
        task_ctx: ProofTaskCtx<Factory>,
        storage_worker_count: usize,
        account_worker_count: usize,
    ) -> Self
    where
        Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
            + Clone
            + Send
            + 'static,
    {
        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
        let (account_work_tx, account_work_rx) = unbounded::<AccountWorkerJob>();

        let storage_available_workers = Arc::new(AtomicUsize::new(0));
        let account_available_workers = Arc::new(AtomicUsize::new(0));

        debug!(
            target: "trie::proof_task",
            storage_worker_count,
            account_worker_count,
            "Spawning proof worker pools"
        );

        let parent_span =
            debug_span!(target: "trie::proof_task", "storage proof workers", ?storage_worker_count)
                .entered();
        for worker_id in 0..storage_worker_count {
            let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id);
            let task_ctx_clone = task_ctx.clone();
            let work_rx_clone = storage_work_rx.clone();
            let storage_available_workers_clone = storage_available_workers.clone();

            executor.spawn_blocking(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();

                let _guard = span.enter();
                let worker = StorageProofWorker::new(
                    task_ctx_clone,
                    work_rx_clone,
                    worker_id,
                    storage_available_workers_clone,
                    #[cfg(feature = "metrics")]
                    metrics,
                );
                worker.run()
            });
        }
        drop(parent_span);

        let parent_span =
            debug_span!(target: "trie::proof_task", "account proof workers", ?account_worker_count)
                .entered();
        for worker_id in 0..account_worker_count {
            let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id);
            let task_ctx_clone = task_ctx.clone();
            let work_rx_clone = account_work_rx.clone();
            let storage_work_tx_clone = storage_work_tx.clone();
            let account_available_workers_clone = account_available_workers.clone();

            executor.spawn_blocking(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();

                let _guard = span.enter();
                let worker = AccountProofWorker::new(
                    task_ctx_clone,
                    work_rx_clone,
                    worker_id,
                    storage_work_tx_clone,
                    account_available_workers_clone,
                    #[cfg(feature = "metrics")]
                    metrics,
                );
                worker.run()
            });
        }
        drop(parent_span);

        Self {
            storage_work_tx,
            account_work_tx,
            storage_available_workers,
            account_available_workers,
            storage_worker_count,
            account_worker_count,
        }
    }

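    /// Returns the number of storage workers currently idle and waiting for work.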
    pub fn available_storage_workers(&self) -> usize {
        self.storage_available_workers.load(Ordering::Relaxed)
    }

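    /// Returns the number of account workers currently idle and waiting for work.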
    pub fn available_account_workers(&self) -> usize {
        self.account_available_workers.load(Ordering::Relaxed)
    }

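    /// Returns the number of storage jobs queued but not yet picked up by a worker.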
    pub fn pending_storage_tasks(&self) -> usize {
        self.storage_work_tx.len()
    }

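    /// Returns the number of account jobs queued but not yet picked up by a worker.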
    pub fn pending_account_tasks(&self) -> usize {
        self.account_work_tx.len()
    }

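    /// Returns the total number of spawned storage workers.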
    pub const fn total_storage_workers(&self) -> usize {
        self.storage_worker_count
    }

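    /// Returns the total number of spawned account workers.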
    pub const fn total_account_workers(&self) -> usize {
        self.account_worker_count
    }

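    /// Returns the number of storage workers currently processing a job.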
    pub fn active_storage_workers(&self) -> usize {
        self.storage_worker_count.saturating_sub(self.available_storage_workers())
    }

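    /// Returns the number of account workers currently processing a job.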
    pub fn active_account_workers(&self) -> usize {
        self.account_worker_count.saturating_sub(self.available_account_workers())
    }

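    /// Queues a storage proof computation on the storage worker pool.
    ///
    /// If the pool is unavailable, an error result is sent back through the provided
    /// [`ProofResultContext`] before the error is returned to the caller.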
    pub fn dispatch_storage_proof(
        &self,
        input: StorageProofInput,
        proof_result_sender: ProofResultContext,
    ) -> Result<(), ProviderError> {
        self.storage_work_tx
            .send(StorageWorkerJob::StorageProof { input, proof_result_sender })
            .map_err(|err| {
                let error =
                    ProviderError::other(std::io::Error::other("storage workers unavailable"));

                if let StorageWorkerJob::StorageProof { proof_result_sender, .. } = err.0 {
                    let ProofResultContext {
                        sender: result_tx,
                        sequence_number: seq,
                        state,
                        start_time: start,
                    } = proof_result_sender;

                    let _ = result_tx.send(ProofResultMessage {
                        sequence_number: seq,
                        result: Err(ParallelStateRootError::Provider(error.clone())),
                        elapsed: start.elapsed(),
                        state,
                    });
                }

                error
            })
    }

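    /// Queues an account multiproof computation on the account worker pool.
    ///
    /// If the pool is unavailable, an error result is sent back through the input's
    /// [`ProofResultContext`] before the error is returned to the caller.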
    pub fn dispatch_account_multiproof(
        &self,
        input: AccountMultiproofInput,
    ) -> Result<(), ProviderError> {
        self.account_work_tx
            .send(AccountWorkerJob::AccountMultiproof { input: Box::new(input) })
            .map_err(|err| {
                let error =
                    ProviderError::other(std::io::Error::other("account workers unavailable"));

                if let AccountWorkerJob::AccountMultiproof { input } = err.0 {
                    let AccountMultiproofInput {
                        proof_result_sender:
                            ProofResultContext {
                                sender: result_tx,
                                sequence_number: seq,
                                state,
                                start_time: start,
                            },
                        ..
                    } = *input;

                    let _ = result_tx.send(ProofResultMessage {
                        sequence_number: seq,
                        result: Err(ParallelStateRootError::Provider(error.clone())),
                        elapsed: start.elapsed(),
                        state,
                    });
                }

                error
            })
    }

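    /// Queues retrieval of a blinded storage trie node and returns a receiver for the result.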
    pub(crate) fn dispatch_blinded_storage_node(
        &self,
        account: B256,
        path: Nibbles,
    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
        let (tx, rx) = channel();
        self.storage_work_tx
            .send(StorageWorkerJob::BlindedStorageNode { account, path, result_sender: tx })
            .map_err(|_| {
                ProviderError::other(std::io::Error::other("storage workers unavailable"))
            })?;

        Ok(rx)
    }

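    /// Queues retrieval of a blinded account trie node and returns a receiver for the result.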
    pub(crate) fn dispatch_blinded_account_node(
        &self,
        path: Nibbles,
    ) -> Result<Receiver<TrieNodeProviderResult>, ProviderError> {
        let (tx, rx) = channel();
        self.account_work_tx
            .send(AccountWorkerJob::BlindedAccountNode { path, result_sender: tx })
            .map_err(|_| {
                ProviderError::other(std::io::Error::other("account workers unavailable"))
            })?;

        Ok(rx)
    }
}

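/// Shared context handed to every proof worker: the provider factory used to open
/// read-only database providers and the shared trie prefix sets.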
#[derive(Clone, Debug)]
pub struct ProofTaskCtx<Factory> {
    factory: Factory,
    prefix_sets: Arc<TriePrefixSetsMut>,
}

impl<Factory> ProofTaskCtx<Factory> {
    pub const fn new(factory: Factory, prefix_sets: Arc<TriePrefixSetsMut>) -> Self {
        Self { factory, prefix_sets }
    }
}

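/// Per-worker execution context bundling a read-only provider, the shared trie prefix
/// sets, and the worker id used in tracing output.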
#[derive(Debug)]
pub struct ProofTaskTx<Provider> {
    provider: Provider,

    prefix_sets: Arc<TriePrefixSetsMut>,

    id: usize,
}

impl<Provider> ProofTaskTx<Provider> {
    const fn new(provider: Provider, prefix_sets: Arc<TriePrefixSetsMut>, id: usize) -> Self {
        Self { provider, prefix_sets, id }
    }
}

impl<Provider> ProofTaskTx<Provider>
where
    Provider: TrieCursorFactory + HashedCursorFactory,
{
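    /// Computes the storage multiproof described by `input` and decodes it.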
    #[inline]
    fn compute_storage_proof(&self, input: StorageProofInput) -> StorageProofResult {
        let StorageProofInput {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        } = input;

        let multi_added_removed_keys =
            multi_added_removed_keys.unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
        let added_removed_keys = multi_added_removed_keys.get_storage(&hashed_address);

        let span = debug_span!(
            target: "trie::proof_task",
            "Storage proof calculation",
            hashed_address = ?hashed_address,
            worker_id = self.id,
        );
        let _span_guard = span.enter();

        let proof_start = Instant::now();

        let raw_proof_result =
            StorageProof::new_hashed(&self.provider, &self.provider, hashed_address)
                .with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().copied()))
                .with_branch_node_masks(with_branch_node_masks)
                .with_added_removed_keys(added_removed_keys)
                .storage_multiproof(target_slots)
                .map_err(|e| ParallelStateRootError::Other(e.to_string()));

        let decoded_result = raw_proof_result.and_then(|raw_proof| {
            raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
                ParallelStateRootError::Other(format!(
                    "Failed to decode storage proof for {}: {}",
                    hashed_address, e
                ))
            })
        });

        trace!(
            target: "trie::proof_task",
            hashed_address = ?hashed_address,
            proof_time_us = proof_start.elapsed().as_micros(),
            worker_id = self.id,
            "Completed storage proof calculation"
        );

        decoded_result
    }

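    /// Retrieves the blinded storage trie node at `path` for the given account.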
    fn process_blinded_storage_node(
        &self,
        account: B256,
        path: &Nibbles,
    ) -> TrieNodeProviderResult {
        let storage_node_provider = ProofBlindedStorageProvider::new(
            &self.provider,
            &self.provider,
            self.prefix_sets.clone(),
            account,
        );
        storage_node_provider.trie_node(path)
    }

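    /// Retrieves the blinded account trie node at `path`.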
    fn process_blinded_account_node(&self, path: &Nibbles) -> TrieNodeProviderResult {
        let account_node_provider = ProofBlindedAccountProvider::new(
            &self.provider,
            &self.provider,
            self.prefix_sets.clone(),
        );
        account_node_provider.trie_node(path)
    }
}

impl TrieNodeProviderFactory for ProofWorkerHandle {
    type AccountNodeProvider = ProofTaskTrieNodeProvider;
    type StorageNodeProvider = ProofTaskTrieNodeProvider;

    fn account_node_provider(&self) -> Self::AccountNodeProvider {
        ProofTaskTrieNodeProvider::AccountNode { handle: self.clone() }
    }

    fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
        ProofTaskTrieNodeProvider::StorageNode { account, handle: self.clone() }
    }
}

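/// Trie node provider that resolves blinded nodes by dispatching requests to the proof
/// worker pools behind a [`ProofWorkerHandle`].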
#[derive(Debug)]
pub enum ProofTaskTrieNodeProvider {
    AccountNode {
        handle: ProofWorkerHandle,
    },
    StorageNode {
        account: B256,
        handle: ProofWorkerHandle,
    },
}

impl TrieNodeProvider for ProofTaskTrieNodeProvider {
    fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
        match self {
            Self::AccountNode { handle } => {
                let rx = handle
                    .dispatch_blinded_account_node(*path)
                    .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
                rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
            }
            Self::StorageNode { handle, account } => {
                let rx = handle
                    .dispatch_blinded_storage_node(*account, *path)
                    .map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?;
                rx.recv().map_err(|error| SparseTrieErrorKind::Other(Box::new(error)))?
            }
        }
    }
}
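
/// A proof computed by a worker: either a full account multiproof or a single account's
/// storage multiproof.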
#[derive(Debug)]
pub enum ProofResult {
    AccountMultiproof {
        proof: DecodedMultiProof,
        stats: ParallelTrieStats,
    },
    StorageProof {
        hashed_address: B256,
        proof: DecodedStorageMultiProof,
    },
}

impl ProofResult {
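    /// Converts this result into a [`DecodedMultiProof`], wrapping a lone storage proof
    /// into a multiproof that contains only that account's storage.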
    pub fn into_multiproof(self) -> DecodedMultiProof {
        match self {
            Self::AccountMultiproof { proof, stats: _ } => proof,
            Self::StorageProof { hashed_address, proof } => {
                DecodedMultiProof::from_storage_proof(hashed_address, proof)
            }
        }
    }
}
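
/// Channel sender used by workers to deliver [`ProofResultMessage`]s back to the requester.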
pub type ProofResultSender = CrossbeamSender<ProofResultMessage>;

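/// Message carrying a finished proof computation back to the requester, along with its
/// sequence number, the time elapsed since dispatch, and the hashed state the proof was
/// computed for.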
#[derive(Debug)]
pub struct ProofResultMessage {
    pub sequence_number: u64,
    pub result: Result<ProofResult, ParallelStateRootError>,
    pub elapsed: Duration,
    pub state: HashedPostState,
}

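/// Everything a worker needs to report a proof result: the result sender, the request's
/// sequence number, the hashed state the proof belongs to, and the dispatch time.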
#[derive(Debug, Clone)]
pub struct ProofResultContext {
    pub sender: ProofResultSender,
    pub sequence_number: u64,
    pub state: HashedPostState,
    pub start_time: Instant,
}

impl ProofResultContext {
    pub const fn new(
        sender: ProofResultSender,
        sequence_number: u64,
        state: HashedPostState,
        start_time: Instant,
    ) -> Self {
        Self { sender, sequence_number, state, start_time }
    }
}
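
/// Jobs processed by storage proof workers.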
#[derive(Debug)]
enum StorageWorkerJob {
    StorageProof {
        input: StorageProofInput,
        proof_result_sender: ProofResultContext,
    },
    BlindedStorageNode {
        account: B256,
        path: Nibbles,
        result_sender: Sender<TrieNodeProviderResult>,
    },
}

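/// Long-lived worker that computes storage proofs and reveals blinded storage trie nodes.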
struct StorageProofWorker<Factory> {
    task_ctx: ProofTaskCtx<Factory>,
    work_rx: CrossbeamReceiver<StorageWorkerJob>,
    worker_id: usize,
    available_workers: Arc<AtomicUsize>,
    #[cfg(feature = "metrics")]
    metrics: ProofTaskTrieMetrics,
}

impl<Factory> StorageProofWorker<Factory>
where
    Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>,
{
    const fn new(
        task_ctx: ProofTaskCtx<Factory>,
        work_rx: CrossbeamReceiver<StorageWorkerJob>,
        worker_id: usize,
        available_workers: Arc<AtomicUsize>,
        #[cfg(feature = "metrics")] metrics: ProofTaskTrieMetrics,
    ) -> Self {
        Self {
            task_ctx,
            work_rx,
            worker_id,
            available_workers,
            #[cfg(feature = "metrics")]
            metrics,
        }
    }

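    /// Worker loop: opens a read-only provider, then processes jobs until the job channel
    /// closes, updating the shared availability counter around each job.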
    fn run(self) {
        let Self {
            task_ctx,
            work_rx,
            worker_id,
            available_workers,
            #[cfg(feature = "metrics")]
            metrics,
        } = self;

        let provider = task_ctx
            .factory
            .database_provider_ro()
            .expect("Storage worker failed to initialize: unable to create provider");
        let proof_tx = ProofTaskTx::new(provider, task_ctx.prefix_sets, worker_id);

        trace!(
            target: "trie::proof_task",
            worker_id,
            "Storage worker started"
        );

        let mut storage_proofs_processed = 0u64;
        let mut storage_nodes_processed = 0u64;

        available_workers.fetch_add(1, Ordering::Relaxed);

        while let Ok(job) = work_rx.recv() {
            available_workers.fetch_sub(1, Ordering::Relaxed);

            match job {
                StorageWorkerJob::StorageProof { input, proof_result_sender } => {
                    Self::process_storage_proof(
                        worker_id,
                        &proof_tx,
                        input,
                        proof_result_sender,
                        &mut storage_proofs_processed,
                    );
                }

                StorageWorkerJob::BlindedStorageNode { account, path, result_sender } => {
                    Self::process_blinded_node(
                        worker_id,
                        &proof_tx,
                        account,
                        path,
                        result_sender,
                        &mut storage_nodes_processed,
                    );
                }
            }

            available_workers.fetch_add(1, Ordering::Relaxed);
        }

        trace!(
            target: "trie::proof_task",
            worker_id,
            storage_proofs_processed,
            storage_nodes_processed,
            "Storage worker shutting down"
        );

        #[cfg(feature = "metrics")]
        metrics.record_storage_nodes(storage_nodes_processed as usize);
    }

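    /// Computes a single storage proof and sends the result through the result context.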
    fn process_storage_proof<Provider>(
        worker_id: usize,
        proof_tx: &ProofTaskTx<Provider>,
        input: StorageProofInput,
        proof_result_sender: ProofResultContext,
        storage_proofs_processed: &mut u64,
    ) where
        Provider: TrieCursorFactory + HashedCursorFactory,
    {
        let hashed_address = input.hashed_address;
        let ProofResultContext { sender, sequence_number: seq, state, start_time } =
            proof_result_sender;

        trace!(
            target: "trie::proof_task",
            worker_id,
            hashed_address = ?hashed_address,
            prefix_set_len = input.prefix_set.len(),
            target_slots_len = input.target_slots.len(),
            "Processing storage proof"
        );

        let proof_start = Instant::now();
        let result = proof_tx.compute_storage_proof(input);

        let proof_elapsed = proof_start.elapsed();
        *storage_proofs_processed += 1;

        let result_msg = result.map(|storage_proof| ProofResult::StorageProof {
            hashed_address,
            proof: storage_proof,
        });

        if sender
            .send(ProofResultMessage {
                sequence_number: seq,
                result: result_msg,
                elapsed: start_time.elapsed(),
                state,
            })
            .is_err()
        {
            trace!(
                target: "trie::proof_task",
                worker_id,
                hashed_address = ?hashed_address,
                storage_proofs_processed,
                "Proof result receiver dropped, discarding result"
            );
        }

        trace!(
            target: "trie::proof_task",
            worker_id,
            hashed_address = ?hashed_address,
            proof_time_us = proof_elapsed.as_micros(),
            total_processed = storage_proofs_processed,
            "Storage proof completed"
        );
    }

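    /// Retrieves a blinded storage trie node and sends the result to the requester.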
    fn process_blinded_node<Provider>(
        worker_id: usize,
        proof_tx: &ProofTaskTx<Provider>,
        account: B256,
        path: Nibbles,
        result_sender: Sender<TrieNodeProviderResult>,
        storage_nodes_processed: &mut u64,
    ) where
        Provider: TrieCursorFactory + HashedCursorFactory,
    {
        trace!(
            target: "trie::proof_task",
            worker_id,
            ?account,
            ?path,
            "Processing blinded storage node"
        );

        let start = Instant::now();
        let result = proof_tx.process_blinded_storage_node(account, &path);
        let elapsed = start.elapsed();

        *storage_nodes_processed += 1;

        if result_sender.send(result).is_err() {
            trace!(
                target: "trie::proof_task",
                worker_id,
                ?account,
                ?path,
                storage_nodes_processed,
                "Blinded storage node receiver dropped, discarding result"
            );
        }

        trace!(
            target: "trie::proof_task",
            worker_id,
            ?account,
            ?path,
            elapsed_us = elapsed.as_micros(),
            total_processed = storage_nodes_processed,
            "Blinded storage node completed"
        );
    }
}

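/// Long-lived worker that computes account multiproofs and reveals blinded account trie
/// nodes, delegating per-account storage proofs to the storage worker pool.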
struct AccountProofWorker<Factory> {
    task_ctx: ProofTaskCtx<Factory>,
    work_rx: CrossbeamReceiver<AccountWorkerJob>,
    worker_id: usize,
    storage_work_tx: CrossbeamSender<StorageWorkerJob>,
    available_workers: Arc<AtomicUsize>,
    #[cfg(feature = "metrics")]
    metrics: ProofTaskTrieMetrics,
}

impl<Factory> AccountProofWorker<Factory>
where
    Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>,
{
    const fn new(
        task_ctx: ProofTaskCtx<Factory>,
        work_rx: CrossbeamReceiver<AccountWorkerJob>,
        worker_id: usize,
        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
        available_workers: Arc<AtomicUsize>,
        #[cfg(feature = "metrics")] metrics: ProofTaskTrieMetrics,
    ) -> Self {
        Self {
            task_ctx,
            work_rx,
            worker_id,
            storage_work_tx,
            available_workers,
            #[cfg(feature = "metrics")]
            metrics,
        }
    }

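    /// Worker loop: opens a read-only provider, then processes jobs until the job channel
    /// closes, updating the shared availability counter around each job.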
    fn run(self) {
        let Self {
            task_ctx,
            work_rx,
            worker_id,
            storage_work_tx,
            available_workers,
            #[cfg(feature = "metrics")]
            metrics,
        } = self;

        let provider = task_ctx
            .factory
            .database_provider_ro()
            .expect("Account worker failed to initialize: unable to create provider");
        let proof_tx = ProofTaskTx::new(provider, task_ctx.prefix_sets, worker_id);

        trace!(
            target: "trie::proof_task",
            worker_id,
            "Account worker started"
        );

        let mut account_proofs_processed = 0u64;
        let mut account_nodes_processed = 0u64;

        available_workers.fetch_add(1, Ordering::Relaxed);

        while let Ok(job) = work_rx.recv() {
            available_workers.fetch_sub(1, Ordering::Relaxed);

            match job {
                AccountWorkerJob::AccountMultiproof { input } => {
                    Self::process_account_multiproof(
                        worker_id,
                        &proof_tx,
                        storage_work_tx.clone(),
                        *input,
                        &mut account_proofs_processed,
                    );
                }

                AccountWorkerJob::BlindedAccountNode { path, result_sender } => {
                    Self::process_blinded_node(
                        worker_id,
                        &proof_tx,
                        path,
                        result_sender,
                        &mut account_nodes_processed,
                    );
                }
            }

            available_workers.fetch_add(1, Ordering::Relaxed);
        }

        trace!(
            target: "trie::proof_task",
            worker_id,
            account_proofs_processed,
            account_nodes_processed,
            "Account worker shutting down"
        );

        #[cfg(feature = "metrics")]
        metrics.record_account_nodes(account_nodes_processed as usize);
    }

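    /// Computes an account multiproof: dispatches the required storage proofs to the
    /// storage worker pool, walks the account trie while collecting their results, and
    /// sends the assembled multiproof through the result context.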
    fn process_account_multiproof<Provider>(
        worker_id: usize,
        proof_tx: &ProofTaskTx<Provider>,
        storage_work_tx: CrossbeamSender<StorageWorkerJob>,
        input: AccountMultiproofInput,
        account_proofs_processed: &mut u64,
    ) where
        Provider: TrieCursorFactory + HashedCursorFactory,
    {
        let AccountMultiproofInput {
            targets,
            mut prefix_sets,
            collect_branch_node_masks,
            multi_added_removed_keys,
            missed_leaves_storage_roots,
            proof_result_sender:
                ProofResultContext { sender: result_tx, sequence_number: seq, state, start_time: start },
        } = input;

        let span = debug_span!(
            target: "trie::proof_task",
            "Account multiproof calculation",
            targets = targets.len(),
            worker_id,
        );
        let _span_guard = span.enter();

        trace!(
            target: "trie::proof_task",
            "Processing account multiproof"
        );

        let proof_start = Instant::now();

        let mut tracker = ParallelTrieTracker::default();

        let mut storage_prefix_sets = std::mem::take(&mut prefix_sets.storage_prefix_sets);

        let storage_root_targets_len =
            StorageRootTargets::count(&prefix_sets.account_prefix_set, &storage_prefix_sets);

        tracker.set_precomputed_storage_roots(storage_root_targets_len as u64);

        let storage_proof_receivers = match dispatch_storage_proofs(
            &storage_work_tx,
            &targets,
            &mut storage_prefix_sets,
            collect_branch_node_masks,
            multi_added_removed_keys.as_ref(),
        ) {
            Ok(receivers) => receivers,
            Err(error) => {
                error!(target: "trie::proof_task", "Failed to dispatch storage proofs: {error}");
                let _ = result_tx.send(ProofResultMessage {
                    sequence_number: seq,
                    result: Err(error),
                    elapsed: start.elapsed(),
                    state,
                });
                return;
            }
        };

        let account_prefix_set = std::mem::take(&mut prefix_sets.account_prefix_set);

        let ctx = AccountMultiproofParams {
            targets: &targets,
            prefix_set: account_prefix_set,
            collect_branch_node_masks,
            multi_added_removed_keys: multi_added_removed_keys.as_ref(),
            storage_proof_receivers,
            missed_leaves_storage_roots: missed_leaves_storage_roots.as_ref(),
        };

        let result =
            build_account_multiproof_with_storage_roots(&proof_tx.provider, ctx, &mut tracker);

        let proof_elapsed = proof_start.elapsed();
        let total_elapsed = start.elapsed();
        let stats = tracker.finish();
        let result = result.map(|proof| ProofResult::AccountMultiproof { proof, stats });
        *account_proofs_processed += 1;

        if result_tx
            .send(ProofResultMessage {
                sequence_number: seq,
                result,
                elapsed: total_elapsed,
                state,
            })
            .is_err()
        {
            trace!(
                target: "trie::proof_task",
                worker_id,
                account_proofs_processed,
                "Account multiproof receiver dropped, discarding result"
            );
        }

        trace!(
            target: "trie::proof_task",
            proof_time_us = proof_elapsed.as_micros(),
            total_elapsed_us = total_elapsed.as_micros(),
            total_processed = account_proofs_processed,
            "Account multiproof completed"
        );
    }

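    /// Retrieves a blinded account trie node and sends the result to the requester.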
    fn process_blinded_node<Provider>(
        worker_id: usize,
        proof_tx: &ProofTaskTx<Provider>,
        path: Nibbles,
        result_sender: Sender<TrieNodeProviderResult>,
        account_nodes_processed: &mut u64,
    ) where
        Provider: TrieCursorFactory + HashedCursorFactory,
    {
        let span = debug_span!(
            target: "trie::proof_task",
            "Blinded account node calculation",
            ?path,
            worker_id,
        );
        let _span_guard = span.enter();

        trace!(
            target: "trie::proof_task",
            "Processing blinded account node"
        );

        let start = Instant::now();
        let result = proof_tx.process_blinded_account_node(&path);
        let elapsed = start.elapsed();

        *account_nodes_processed += 1;

        if result_sender.send(result).is_err() {
            trace!(
                target: "trie::proof_task",
                worker_id,
                ?path,
                account_nodes_processed,
                "Blinded account node receiver dropped, discarding result"
            );
        }

        trace!(
            target: "trie::proof_task",
            node_time_us = elapsed.as_micros(),
            total_processed = account_nodes_processed,
            "Blinded account node completed"
        );
    }
}

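/// Builds the account multiproof by walking the account trie, consuming storage proof
/// results from the storage workers where available and computing storage roots for
/// missed leaves inline.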
fn build_account_multiproof_with_storage_roots<P>(
    provider: &P,
    ctx: AccountMultiproofParams<'_>,
    tracker: &mut ParallelTrieTracker,
) -> Result<DecodedMultiProof, ParallelStateRootError>
where
    P: TrieCursorFactory + HashedCursorFactory,
{
    let accounts_added_removed_keys =
        ctx.multi_added_removed_keys.as_ref().map(|keys| keys.get_accounts());

    let walker = TrieWalker::<_>::state_trie(
        provider.account_trie_cursor().map_err(ProviderError::Database)?,
        ctx.prefix_set,
    )
    .with_added_removed_keys(accounts_added_removed_keys)
    .with_deletions_retained(true);

    let retainer = ctx
        .targets
        .keys()
        .map(Nibbles::unpack)
        .collect::<ProofRetainer>()
        .with_added_removed_keys(accounts_added_removed_keys);
    let mut hash_builder = HashBuilder::default()
        .with_proof_retainer(retainer)
        .with_updates(ctx.collect_branch_node_masks);

    let mut collected_decoded_storages: B256Map<DecodedStorageMultiProof> =
        B256Map::with_capacity_and_hasher(ctx.targets.len(), Default::default());
    let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
    let mut account_node_iter = TrieNodeIter::state_trie(
        walker,
        provider.hashed_account_cursor().map_err(ProviderError::Database)?,
    );

    let mut storage_proof_receivers = ctx.storage_proof_receivers;

    while let Some(account_node) = account_node_iter.try_next().map_err(ProviderError::Database)? {
        match account_node {
            TrieElement::Branch(node) => {
                hash_builder.add_branch(node.key, node.value, node.children_are_in_trie);
            }
            TrieElement::Leaf(hashed_address, account) => {
                let root = match storage_proof_receivers.remove(&hashed_address) {
                    Some(receiver) => {
                        let proof_msg = receiver.recv().map_err(|_| {
                            ParallelStateRootError::StorageRoot(
                                reth_execution_errors::StorageRootError::Database(
                                    DatabaseError::Other(format!(
                                        "Storage proof channel closed for {hashed_address}"
                                    )),
                                ),
                            )
                        })?;

                        let proof = match proof_msg.result? {
                            ProofResult::StorageProof { hashed_address: addr, proof } => {
                                debug_assert_eq!(
                                    addr,
                                    hashed_address,
                                    "storage worker must return same address: expected {hashed_address}, got {addr}"
                                );
                                proof
                            }
                            ProofResult::AccountMultiproof { .. } => {
                                unreachable!("storage worker only sends StorageProof variant")
                            }
                        };

                        let root = proof.root;
                        collected_decoded_storages.insert(hashed_address, proof);
                        root
                    }
                    None => {
                        tracker.inc_missed_leaves();

                        match ctx.missed_leaves_storage_roots.entry(hashed_address) {
                            dashmap::Entry::Occupied(occ) => *occ.get(),
                            dashmap::Entry::Vacant(vac) => {
                                let root =
                                    StorageProof::new_hashed(provider, provider, hashed_address)
                                        .with_prefix_set_mut(Default::default())
                                        .storage_multiproof(
                                            ctx.targets
                                                .get(&hashed_address)
                                                .cloned()
                                                .unwrap_or_default(),
                                        )
                                        .map_err(|e| {
                                            ParallelStateRootError::StorageRoot(
                                                reth_execution_errors::StorageRootError::Database(
                                                    DatabaseError::Other(e.to_string()),
                                                ),
                                            )
                                        })?
                                        .root;

                                vac.insert(root);
                                root
                            }
                        }
                    }
                };

                account_rlp.clear();
                let account = account.into_trie_account(root);
                account.encode(&mut account_rlp as &mut dyn BufMut);

                hash_builder.add_leaf(Nibbles::unpack(hashed_address), &account_rlp);
            }
        }
    }

    for (hashed_address, receiver) in storage_proof_receivers {
        if let Ok(proof_msg) = receiver.recv() {
            if let Ok(ProofResult::StorageProof { proof, .. }) = proof_msg.result {
                collected_decoded_storages.insert(hashed_address, proof);
            }
        }
    }

    let _ = hash_builder.root();

    let account_subtree_raw_nodes = hash_builder.take_proof_nodes();
    let decoded_account_subtree = DecodedProofNodes::try_from(account_subtree_raw_nodes)?;

    let (branch_node_hash_masks, branch_node_tree_masks) = if ctx.collect_branch_node_masks {
        let updated_branch_nodes = hash_builder.updated_branch_nodes.unwrap_or_default();
        (
            updated_branch_nodes.iter().map(|(path, node)| (*path, node.hash_mask)).collect(),
            updated_branch_nodes.into_iter().map(|(path, node)| (path, node.tree_mask)).collect(),
        )
    } else {
        (Default::default(), Default::default())
    };

    Ok(DecodedMultiProof {
        account_subtree: decoded_account_subtree,
        branch_node_hash_masks,
        branch_node_tree_masks,
        storages: collected_decoded_storages,
    })
}
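
/// Dispatches one storage proof job per target account to the storage worker pool and
/// returns a receiver for each dispatched proof, keyed by hashed address.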
fn dispatch_storage_proofs(
    storage_work_tx: &CrossbeamSender<StorageWorkerJob>,
    targets: &MultiProofTargets,
    storage_prefix_sets: &mut B256Map<PrefixSet>,
    with_branch_node_masks: bool,
    multi_added_removed_keys: Option<&Arc<MultiAddedRemovedKeys>>,
) -> Result<B256Map<CrossbeamReceiver<ProofResultMessage>>, ParallelStateRootError> {
    let mut storage_proof_receivers =
        B256Map::with_capacity_and_hasher(targets.len(), Default::default());

    for (hashed_address, target_slots) in targets.iter() {
        let prefix_set = storage_prefix_sets.remove(hashed_address).unwrap_or_default();

        let (result_tx, result_rx) = crossbeam_channel::unbounded();
        let start = Instant::now();

        let input = StorageProofInput::new(
            *hashed_address,
            prefix_set,
            target_slots.clone(),
            with_branch_node_masks,
            multi_added_removed_keys.cloned(),
        );

        storage_work_tx
            .send(StorageWorkerJob::StorageProof {
                input,
                proof_result_sender: ProofResultContext::new(
                    result_tx,
                    0,
                    HashedPostState::default(),
                    start,
                ),
            })
            .map_err(|_| {
                ParallelStateRootError::Other(format!(
                    "Failed to queue storage proof for {}: storage worker pool unavailable",
                    hashed_address
                ))
            })?;

        storage_proof_receivers.insert(*hashed_address, result_rx);
    }

    Ok(storage_proof_receivers)
}
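
/// Input parameters for a single storage proof computation.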
#[derive(Debug)]
pub struct StorageProofInput {
    hashed_address: B256,
    prefix_set: PrefixSet,
    target_slots: B256Set,
    with_branch_node_masks: bool,
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}

impl StorageProofInput {
    pub const fn new(
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
        with_branch_node_masks: bool,
        multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    ) -> Self {
        Self {
            hashed_address,
            prefix_set,
            target_slots,
            with_branch_node_masks,
            multi_added_removed_keys,
        }
    }
}
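
/// Input parameters for an account multiproof computation, including the result context
/// used to deliver the finished proof.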
#[derive(Debug, Clone)]
pub struct AccountMultiproofInput {
    pub targets: MultiProofTargets,
    pub prefix_sets: TriePrefixSets,
    pub collect_branch_node_masks: bool,
    pub multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
    pub missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
    pub proof_result_sender: ProofResultContext,
}

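/// Borrowed parameters passed to [`build_account_multiproof_with_storage_roots`].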
struct AccountMultiproofParams<'a> {
    targets: &'a MultiProofTargets,
    prefix_set: PrefixSet,
    collect_branch_node_masks: bool,
    multi_added_removed_keys: Option<&'a Arc<MultiAddedRemovedKeys>>,
    storage_proof_receivers: B256Map<CrossbeamReceiver<ProofResultMessage>>,
    missed_leaves_storage_roots: &'a DashMap<B256, B256>,
}

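/// Jobs processed by account proof workers.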
#[derive(Debug)]
enum AccountWorkerJob {
    AccountMultiproof {
        input: Box<AccountMultiproofInput>,
    },
    BlindedAccountNode {
        path: Nibbles,
        result_sender: Sender<TrieNodeProviderResult>,
    },
}

#[cfg(test)]
mod tests {
    use super::*;
    use reth_provider::test_utils::create_test_provider_factory;
    use reth_trie_common::prefix_set::TriePrefixSetsMut;
    use std::sync::Arc;
    use tokio::{runtime::Builder, task};

    fn test_ctx<Factory>(factory: Factory) -> ProofTaskCtx<Factory> {
        ProofTaskCtx::new(factory, Arc::new(TriePrefixSetsMut::default()))
    }

    #[test]
    fn spawn_proof_workers_creates_handle() {
        let runtime = Builder::new_multi_thread().worker_threads(1).enable_all().build().unwrap();
        runtime.block_on(async {
            let handle = tokio::runtime::Handle::current();
            let provider_factory = create_test_provider_factory();
            let factory =
                reth_provider::providers::OverlayStateProviderFactory::new(provider_factory);
            let ctx = test_ctx(factory);

            let proof_handle = ProofWorkerHandle::new(handle.clone(), ctx, 5, 3);

            let _cloned_handle = proof_handle.clone();

            drop(proof_handle);
            task::yield_now().await;
        });
    }
}