1use crate::root::ParallelStateRootError;
12use alloy_primitives::{map::B256Set, B256};
13use reth_db_api::transaction::DbTx;
use reth_execution_errors::{SparseTrieError, SparseTrieErrorKind};
15use reth_provider::{
16 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
17 ProviderResult,
18};
19use reth_trie::{
20 hashed_cursor::HashedPostStateCursorFactory,
21 prefix_set::TriePrefixSetsMut,
22 proof::{ProofTrieNodeProviderFactory, StorageProof},
23 trie_cursor::InMemoryTrieCursorFactory,
24 updates::TrieUpdatesSorted,
25 DecodedStorageMultiProof, HashedPostStateSorted, Nibbles,
26};
27use reth_trie_common::{
28 added_removed_keys::MultiAddedRemovedKeys,
29 prefix_set::{PrefixSet, PrefixSetMut},
30};
31use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
32use reth_trie_sparse::provider::{RevealedNode, TrieNodeProvider, TrieNodeProviderFactory};
33use std::{
34 collections::VecDeque,
35 sync::{
36 atomic::{AtomicUsize, Ordering},
37 mpsc::{channel, Receiver, SendError, Sender},
38 Arc,
39 },
40 time::Instant,
41};
42use tokio::runtime::Handle;
43use tracing::debug;
44
45#[cfg(feature = "metrics")]
46use crate::proof_task_metrics::ProofTaskMetrics;
47
48type StorageProofResult = Result<DecodedStorageMultiProof, ParallelStateRootError>;
49type TrieNodeProviderResult = Result<Option<RevealedNode>, SparseTrieError>;
50
51#[derive(Debug)]
54pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
55 max_concurrency: usize,
57 total_transactions: usize,
59 view: ConsistentDbView<Factory>,
61 task_ctx: ProofTaskCtx,
63 pending_tasks: VecDeque<ProofTaskKind>,
65 executor: Handle,
67 proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
70 proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
72 tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
74 active_handles: Arc<AtomicUsize>,
79 #[cfg(feature = "metrics")]
81 metrics: ProofTaskMetrics,
82}
83
84impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
85 pub fn new(
90 executor: Handle,
91 view: ConsistentDbView<Factory>,
92 task_ctx: ProofTaskCtx,
93 max_concurrency: usize,
94 ) -> Self {
95 let (tx_sender, proof_task_rx) = channel();
96 Self {
97 max_concurrency,
98 total_transactions: 0,
99 view,
100 task_ctx,
101 pending_tasks: VecDeque::new(),
102 executor,
103 proof_task_txs: Vec::new(),
104 proof_task_rx,
105 tx_sender,
106 active_handles: Arc::new(AtomicUsize::new(0)),
107 #[cfg(feature = "metrics")]
108 metrics: ProofTaskMetrics::default(),
109 }
110 }
111
112 pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
114 ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
115 }
116}
117
118impl<Factory> ProofTaskManager<Factory>
119where
120 Factory: DatabaseProviderFactory<Provider: BlockReader> + 'static,
121{
122 pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
124 self.pending_tasks.push_back(task);
125 }
126
127 pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
130 if let Some(proof_task_tx) = self.proof_task_txs.pop() {
131 return Ok(Some(proof_task_tx));
132 }
133
134 if self.total_transactions < self.max_concurrency {
136 let provider_ro = self.view.provider_ro()?;
137 let tx = provider_ro.into_tx();
138 self.total_transactions += 1;
139 return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone(), self.total_transactions)));
140 }
141
142 Ok(None)
143 }
144
145 pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
151 let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };
152
153 let Some(proof_task_tx) = self.get_or_create_tx()? else {
154 self.pending_tasks.push_front(task);
156 return Ok(())
157 };
158
159 let tx_sender = self.tx_sender.clone();
160 self.executor.spawn_blocking(move || match task {
161 ProofTaskKind::StorageProof(input, sender) => {
162 proof_task_tx.storage_proof(input, sender, tx_sender);
163 }
164 ProofTaskKind::BlindedAccountNode(path, sender) => {
165 proof_task_tx.blinded_account_node(path, sender, tx_sender);
166 }
167 ProofTaskKind::BlindedStorageNode(account, path, sender) => {
168 proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
169 }
170 });
171
172 Ok(())
173 }
174
175 pub fn run(mut self) -> ProviderResult<()> {
177 loop {
178 match self.proof_task_rx.recv() {
179 Ok(message) => match message {
180 ProofTaskMessage::QueueTask(task) => {
181 #[cfg(feature = "metrics")]
183 match &task {
184 ProofTaskKind::BlindedAccountNode(_, _) => {
185 self.metrics.account_nodes += 1;
186 }
187 ProofTaskKind::BlindedStorageNode(_, _, _) => {
188 self.metrics.storage_nodes += 1;
189 }
190 _ => {}
191 }
192 self.queue_proof_task(task)
194 }
195 ProofTaskMessage::Transaction(tx) => {
196 self.proof_task_txs.push(tx);
198 }
199 ProofTaskMessage::Terminate => {
200 #[cfg(feature = "metrics")]
202 self.metrics.record();
203 return Ok(())
204 }
205 },
206 Err(_) => return Ok(()),
209 };
210
211 self.try_spawn_next()?;
213 }
214 }
215}
216
217#[derive(Debug)]
219pub struct ProofTaskTx<Tx> {
220 tx: Tx,
222
223 task_ctx: ProofTaskCtx,
225
226 id: usize,
229}
230
231impl<Tx> ProofTaskTx<Tx> {
232 const fn new(tx: Tx, task_ctx: ProofTaskCtx, id: usize) -> Self {
235 Self { tx, task_ctx, id }
236 }
237}
238
239impl<Tx> ProofTaskTx<Tx>
240where
241 Tx: DbTx,
242{
243 fn create_factories(
244 &self,
245 ) -> (
246 InMemoryTrieCursorFactory<'_, DatabaseTrieCursorFactory<'_, Tx>>,
247 HashedPostStateCursorFactory<'_, DatabaseHashedCursorFactory<'_, Tx>>,
248 ) {
249 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
250 DatabaseTrieCursorFactory::new(&self.tx),
251 &self.task_ctx.nodes_sorted,
252 );
253
254 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
255 DatabaseHashedCursorFactory::new(&self.tx),
256 &self.task_ctx.state_sorted,
257 );
258
259 (trie_cursor_factory, hashed_cursor_factory)
260 }
261
262 fn storage_proof(
264 self,
265 input: StorageProofInput,
266 result_sender: Sender<StorageProofResult>,
267 tx_sender: Sender<ProofTaskMessage<Tx>>,
268 ) {
269 debug!(
270 target: "trie::proof_task",
271 hashed_address=?input.hashed_address,
272 "Starting storage proof task calculation"
273 );
274
275 let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
276 let multi_added_removed_keys = input
277 .multi_added_removed_keys
278 .unwrap_or_else(|| Arc::new(MultiAddedRemovedKeys::new()));
279 let added_removed_keys = multi_added_removed_keys.get_storage(&input.hashed_address);
280
281 let span = tracing::trace_span!(
282 target: "trie::proof_task",
283 "Storage proof calculation",
284 hashed_address=?input.hashed_address,
285 span_id=self.id,
288 );
289 let span_guard = span.enter();
290
291 let target_slots_len = input.target_slots.len();
292 let proof_start = Instant::now();
293
294 let raw_proof_result = StorageProof::new_hashed(
295 trie_cursor_factory,
296 hashed_cursor_factory,
297 input.hashed_address,
298 )
299 .with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().copied()))
300 .with_branch_node_masks(input.with_branch_node_masks)
301 .with_added_removed_keys(added_removed_keys)
302 .storage_multiproof(input.target_slots)
303 .map_err(|e| ParallelStateRootError::Other(e.to_string()));
304
305 drop(span_guard);
306
307 let decoded_result = raw_proof_result.and_then(|raw_proof| {
308 raw_proof.try_into().map_err(|e: alloy_rlp::Error| {
309 ParallelStateRootError::Other(format!(
310 "Failed to decode storage proof for {}: {}",
311 input.hashed_address, e
312 ))
313 })
314 });
315
316 debug!(
317 target: "trie::proof_task",
318 hashed_address=?input.hashed_address,
319 prefix_set = ?input.prefix_set.len(),
320 target_slots = ?target_slots_len,
321 proof_time = ?proof_start.elapsed(),
322 "Completed storage proof task calculation"
323 );
324
325 if let Err(error) = result_sender.send(decoded_result) {
327 debug!(
328 target: "trie::proof_task",
329 hashed_address = ?input.hashed_address,
330 ?error,
331 task_time = ?proof_start.elapsed(),
332 "Storage proof receiver is dropped, discarding the result"
333 );
334 }
335
336 let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
338 }
339
340 fn blinded_account_node(
342 self,
343 path: Nibbles,
344 result_sender: Sender<TrieNodeProviderResult>,
345 tx_sender: Sender<ProofTaskMessage<Tx>>,
346 ) {
347 debug!(
348 target: "trie::proof_task",
349 ?path,
350 "Starting blinded account node retrieval"
351 );
352
353 let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
354
355 let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
356 trie_cursor_factory,
357 hashed_cursor_factory,
358 self.task_ctx.prefix_sets.clone(),
359 );
360
361 let start = Instant::now();
362 let result = blinded_provider_factory.account_node_provider().trie_node(&path);
363 debug!(
364 target: "trie::proof_task",
365 ?path,
366 elapsed = ?start.elapsed(),
367 "Completed blinded account node retrieval"
368 );
369
370 if let Err(error) = result_sender.send(result) {
371 tracing::error!(
372 target: "trie::proof_task",
373 ?path,
374 ?error,
375 "Failed to send blinded account node result"
376 );
377 }
378
379 let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
381 }
382
383 fn blinded_storage_node(
385 self,
386 account: B256,
387 path: Nibbles,
388 result_sender: Sender<TrieNodeProviderResult>,
389 tx_sender: Sender<ProofTaskMessage<Tx>>,
390 ) {
391 debug!(
392 target: "trie::proof_task",
393 ?account,
394 ?path,
395 "Starting blinded storage node retrieval"
396 );
397
398 let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
399
400 let blinded_provider_factory = ProofTrieNodeProviderFactory::new(
401 trie_cursor_factory,
402 hashed_cursor_factory,
403 self.task_ctx.prefix_sets.clone(),
404 );
405
406 let start = Instant::now();
407 let result = blinded_provider_factory.storage_node_provider(account).trie_node(&path);
408 debug!(
409 target: "trie::proof_task",
410 ?account,
411 ?path,
412 elapsed = ?start.elapsed(),
413 "Completed blinded storage node retrieval"
414 );
415
416 if let Err(error) = result_sender.send(result) {
417 tracing::error!(
418 target: "trie::proof_task",
419 ?account,
420 ?path,
421 ?error,
422 "Failed to send blinded storage node result"
423 );
424 }
425
426 let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
428 }
429}
430
431#[derive(Debug)]
433pub struct StorageProofInput {
434 hashed_address: B256,
436 prefix_set: PrefixSet,
438 target_slots: B256Set,
440 with_branch_node_masks: bool,
442 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
444}
445
446impl StorageProofInput {
447 pub const fn new(
450 hashed_address: B256,
451 prefix_set: PrefixSet,
452 target_slots: B256Set,
453 with_branch_node_masks: bool,
454 multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
455 ) -> Self {
456 Self {
457 hashed_address,
458 prefix_set,
459 target_slots,
460 with_branch_node_masks,
461 multi_added_removed_keys,
462 }
463 }
464}
465
466#[derive(Debug, Clone)]
468pub struct ProofTaskCtx {
469 nodes_sorted: Arc<TrieUpdatesSorted>,
472 state_sorted: Arc<HashedPostStateSorted>,
474 prefix_sets: Arc<TriePrefixSetsMut>,
478}
479
480impl ProofTaskCtx {
481 pub const fn new(
483 nodes_sorted: Arc<TrieUpdatesSorted>,
484 state_sorted: Arc<HashedPostStateSorted>,
485 prefix_sets: Arc<TriePrefixSetsMut>,
486 ) -> Self {
487 Self { nodes_sorted, state_sorted, prefix_sets }
488 }
489}
490
491#[derive(Debug)]
493pub enum ProofTaskMessage<Tx> {
494 QueueTask(ProofTaskKind),
496 Transaction(ProofTaskTx<Tx>),
498 Terminate,
500}
501
502#[derive(Debug)]
507pub enum ProofTaskKind {
508 StorageProof(StorageProofInput, Sender<StorageProofResult>),
510 BlindedAccountNode(Nibbles, Sender<TrieNodeProviderResult>),
512 BlindedStorageNode(B256, Nibbles, Sender<TrieNodeProviderResult>),
514}
515
516#[derive(Debug)]
519pub struct ProofTaskManagerHandle<Tx> {
520 sender: Sender<ProofTaskMessage<Tx>>,
522 active_handles: Arc<AtomicUsize>,
524}
525
526impl<Tx> ProofTaskManagerHandle<Tx> {
527 pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
529 active_handles.fetch_add(1, Ordering::SeqCst);
530 Self { sender, active_handles }
531 }
532
533 pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
535 self.sender.send(ProofTaskMessage::QueueTask(task))
536 }
537
538 pub fn terminate(&self) {
540 let _ = self.sender.send(ProofTaskMessage::Terminate);
541 }
542}
543
544impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
545 fn clone(&self) -> Self {
546 Self::new(self.sender.clone(), self.active_handles.clone())
547 }
548}
549
550impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
551 fn drop(&mut self) {
552 if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
555 self.terminate();
556 }
557 }
558}
559
560impl<Tx: DbTx> TrieNodeProviderFactory for ProofTaskManagerHandle<Tx> {
561 type AccountNodeProvider = ProofTaskTrieNodeProvider<Tx>;
562 type StorageNodeProvider = ProofTaskTrieNodeProvider<Tx>;
563
564 fn account_node_provider(&self) -> Self::AccountNodeProvider {
565 ProofTaskTrieNodeProvider::AccountNode { sender: self.sender.clone() }
566 }
567
568 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
569 ProofTaskTrieNodeProvider::StorageNode { account, sender: self.sender.clone() }
570 }
571}
572
573#[derive(Debug)]
575pub enum ProofTaskTrieNodeProvider<Tx> {
576 AccountNode {
578 sender: Sender<ProofTaskMessage<Tx>>,
580 },
581 StorageNode {
583 account: B256,
585 sender: Sender<ProofTaskMessage<Tx>>,
587 },
588}
589
590impl<Tx: DbTx> TrieNodeProvider for ProofTaskTrieNodeProvider<Tx> {
591 fn trie_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
592 let (tx, rx) = channel();
593 match self {
594 Self::AccountNode { sender } => {
595 let _ = sender.send(ProofTaskMessage::QueueTask(
596 ProofTaskKind::BlindedAccountNode(*path, tx),
597 ));
598 }
599 Self::StorageNode { sender, account } => {
600 let _ = sender.send(ProofTaskMessage::QueueTask(
601 ProofTaskKind::BlindedStorageNode(*account, *path, tx),
602 ));
603 }
604 }
605
606 rx.recv().unwrap()
607 }
608}