1use crate::root::ParallelStateRootError;
12use alloy_primitives::{map::B256Set, B256};
13use reth_db_api::transaction::DbTx;
14use reth_execution_errors::SparseTrieError;
15use reth_provider::{
16 providers::ConsistentDbView, BlockReader, DBProvider, DatabaseProviderFactory, FactoryTx,
17 ProviderResult, StateCommitmentProvider,
18};
19use reth_trie::{
20 hashed_cursor::HashedPostStateCursorFactory,
21 prefix_set::TriePrefixSetsMut,
22 proof::{ProofBlindedProviderFactory, StorageProof},
23 trie_cursor::InMemoryTrieCursorFactory,
24 updates::TrieUpdatesSorted,
25 HashedPostStateSorted, Nibbles, StorageMultiProof,
26};
27use reth_trie_common::prefix_set::{PrefixSet, PrefixSetMut};
28use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
29use reth_trie_sparse::blinded::{BlindedProvider, BlindedProviderFactory, RevealedNode};
30use std::{
31 collections::VecDeque,
32 sync::{
33 atomic::{AtomicUsize, Ordering},
34 mpsc::{channel, Receiver, SendError, Sender},
35 Arc,
36 },
37 time::Instant,
38};
39use tokio::runtime::Handle;
40use tracing::debug;
41
/// Outcome of a storage proof task: the computed [`StorageMultiProof`] or a
/// [`ParallelStateRootError`].
type StorageProofResult = Result<StorageMultiProof, ParallelStateRootError>;
/// Outcome of a blinded node lookup: the revealed node, if any, or a [`SparseTrieError`].
type BlindedNodeResult = Result<Option<RevealedNode>, SparseTrieError>;
44
/// Schedules proof tasks onto a bounded set of database transactions.
///
/// Tasks and returned transactions arrive as [`ProofTaskMessage`]s on an internal channel; the
/// manager pairs each pending task with an idle (or newly created) transaction and runs it on
/// the executor as a blocking task.
#[derive(Debug)]
pub struct ProofTaskManager<Factory: DatabaseProviderFactory> {
    /// Upper bound for `total_transactions`; no new transaction is created once it is reached.
    max_concurrency: usize,
    /// Number of transactions created by this manager so far.
    total_transactions: usize,
    /// Database view from which read-only providers (and their transactions) are obtained.
    view: ConsistentDbView<Factory>,
    /// Shared task context, cloned into every newly created [`ProofTaskTx`].
    task_ctx: ProofTaskCtx,
    /// Tasks that have been queued but not yet spawned, in FIFO order.
    pending_tasks: VecDeque<ProofTaskKind>,
    /// Tokio runtime handle used to spawn the blocking proof tasks.
    executor: Handle,
    /// Idle transactions that are ready to be handed to the next task.
    proof_task_txs: Vec<ProofTaskTx<FactoryTx<Factory>>>,
    /// Receiving end of the manager's message channel, drained by `run`.
    proof_task_rx: Receiver<ProofTaskMessage<FactoryTx<Factory>>>,
    /// Sending end of the manager's message channel; cloned into handles and spawned tasks.
    tx_sender: Sender<ProofTaskMessage<FactoryTx<Factory>>>,
    /// Counter shared with every [`ProofTaskManagerHandle`] created from this manager.
    active_handles: Arc<AtomicUsize>,
}
74
75impl<Factory: DatabaseProviderFactory> ProofTaskManager<Factory> {
76 pub fn new(
81 executor: Handle,
82 view: ConsistentDbView<Factory>,
83 task_ctx: ProofTaskCtx,
84 max_concurrency: usize,
85 ) -> Self {
86 let (tx_sender, proof_task_rx) = channel();
87 Self {
88 max_concurrency,
89 total_transactions: 0,
90 view,
91 task_ctx,
92 pending_tasks: VecDeque::new(),
93 executor,
94 proof_task_txs: Vec::new(),
95 proof_task_rx,
96 tx_sender,
97 active_handles: Arc::new(AtomicUsize::new(0)),
98 }
99 }
100
101 pub fn handle(&self) -> ProofTaskManagerHandle<FactoryTx<Factory>> {
103 ProofTaskManagerHandle::new(self.tx_sender.clone(), self.active_handles.clone())
104 }
105}
106
107impl<Factory> ProofTaskManager<Factory>
108where
109 Factory: DatabaseProviderFactory<Provider: BlockReader> + StateCommitmentProvider + 'static,
110{
111 pub fn queue_proof_task(&mut self, task: ProofTaskKind) {
113 self.pending_tasks.push_back(task);
114 }
115
116 pub fn get_or_create_tx(&mut self) -> ProviderResult<Option<ProofTaskTx<FactoryTx<Factory>>>> {
119 if let Some(proof_task_tx) = self.proof_task_txs.pop() {
120 return Ok(Some(proof_task_tx));
121 }
122
123 if self.total_transactions < self.max_concurrency {
125 let provider_ro = self.view.provider_ro()?;
126 let tx = provider_ro.into_tx();
127 self.total_transactions += 1;
128 return Ok(Some(ProofTaskTx::new(tx, self.task_ctx.clone())));
129 }
130
131 Ok(None)
132 }
133
134 pub fn try_spawn_next(&mut self) -> ProviderResult<()> {
140 let Some(task) = self.pending_tasks.pop_front() else { return Ok(()) };
141
142 let Some(proof_task_tx) = self.get_or_create_tx()? else {
143 self.pending_tasks.push_front(task);
145 return Ok(())
146 };
147
148 let tx_sender = self.tx_sender.clone();
149 self.executor.spawn_blocking(move || match task {
150 ProofTaskKind::StorageProof(input, sender) => {
151 proof_task_tx.storage_proof(input, sender, tx_sender);
152 }
153 ProofTaskKind::BlindedAccountNode(path, sender) => {
154 proof_task_tx.blinded_account_node(path, sender, tx_sender);
155 }
156 ProofTaskKind::BlindedStorageNode(account, path, sender) => {
157 proof_task_tx.blinded_storage_node(account, path, sender, tx_sender);
158 }
159 });
160
161 Ok(())
162 }
163
164 pub fn run(mut self) -> ProviderResult<()> {
166 loop {
167 match self.proof_task_rx.recv() {
168 Ok(message) => match message {
169 ProofTaskMessage::QueueTask(task) => {
170 self.queue_proof_task(task)
172 }
173 ProofTaskMessage::Transaction(tx) => {
174 self.proof_task_txs.push(tx);
176 }
177 ProofTaskMessage::Terminate => return Ok(()),
178 },
179 Err(_) => return Ok(()),
182 };
183
184 self.try_spawn_next()?;
186 }
187 }
188}
189
/// A database transaction bundled with the context needed to run proof tasks on it.
#[derive(Debug)]
pub struct ProofTaskTx<Tx> {
    /// The database transaction the cursor factories read from.
    tx: Tx,

    /// Shared in-memory overlays and prefix sets used when building the cursor factories.
    task_ctx: ProofTaskCtx,
}
199
impl<Tx> ProofTaskTx<Tx> {
    /// Wraps `tx` together with the task context it will be used with.
    const fn new(tx: Tx, task_ctx: ProofTaskCtx) -> Self {
        Self { tx, task_ctx }
    }
}
206
207impl<Tx> ProofTaskTx<Tx>
208where
209 Tx: DbTx,
210{
211 fn create_factories(
212 &self,
213 ) -> (
214 InMemoryTrieCursorFactory<'_, DatabaseTrieCursorFactory<'_, Tx>>,
215 HashedPostStateCursorFactory<'_, DatabaseHashedCursorFactory<'_, Tx>>,
216 ) {
217 let trie_cursor_factory = InMemoryTrieCursorFactory::new(
218 DatabaseTrieCursorFactory::new(&self.tx),
219 &self.task_ctx.nodes_sorted,
220 );
221
222 let hashed_cursor_factory = HashedPostStateCursorFactory::new(
223 DatabaseHashedCursorFactory::new(&self.tx),
224 &self.task_ctx.state_sorted,
225 );
226
227 (trie_cursor_factory, hashed_cursor_factory)
228 }
229
230 fn storage_proof(
232 self,
233 input: StorageProofInput,
234 result_sender: Sender<StorageProofResult>,
235 tx_sender: Sender<ProofTaskMessage<Tx>>,
236 ) {
237 debug!(
238 target: "trie::proof_task",
239 hashed_address=?input.hashed_address,
240 "Starting storage proof task calculation"
241 );
242
243 let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
244
245 let target_slots_len = input.target_slots.len();
246 let proof_start = Instant::now();
247 let result = StorageProof::new_hashed(
248 trie_cursor_factory,
249 hashed_cursor_factory,
250 input.hashed_address,
251 )
252 .with_prefix_set_mut(PrefixSetMut::from(input.prefix_set.iter().cloned()))
253 .with_branch_node_masks(input.with_branch_node_masks)
254 .storage_multiproof(input.target_slots)
255 .map_err(|e| ParallelStateRootError::Other(e.to_string()));
256
257 debug!(
258 target: "trie::proof_task",
259 hashed_address=?input.hashed_address,
260 prefix_set = ?input.prefix_set.len(),
261 target_slots = ?target_slots_len,
262 proof_time = ?proof_start.elapsed(),
263 "Completed storage proof task calculation"
264 );
265
266 if let Err(error) = result_sender.send(result) {
268 debug!(
269 target: "trie::proof_task",
270 hashed_address = ?input.hashed_address,
271 ?error,
272 task_time = ?proof_start.elapsed(),
273 "Failed to send proof result"
274 );
275 }
276
277 let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
279 }
280
281 fn blinded_account_node(
283 self,
284 path: Nibbles,
285 result_sender: Sender<BlindedNodeResult>,
286 tx_sender: Sender<ProofTaskMessage<Tx>>,
287 ) {
288 debug!(
289 target: "trie::proof_task",
290 ?path,
291 "Starting blinded account node retrieval"
292 );
293
294 let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
295
296 let blinded_provider_factory = ProofBlindedProviderFactory::new(
297 trie_cursor_factory,
298 hashed_cursor_factory,
299 self.task_ctx.prefix_sets.clone(),
300 );
301
302 let start = Instant::now();
303 let result = blinded_provider_factory.account_node_provider().blinded_node(&path);
304 debug!(
305 target: "trie::proof_task",
306 ?path,
307 elapsed = ?start.elapsed(),
308 "Completed blinded account node retrieval"
309 );
310
311 if let Err(error) = result_sender.send(result) {
312 tracing::error!(
313 target: "trie::proof_task",
314 ?path,
315 ?error,
316 "Failed to send blinded account node result"
317 );
318 }
319
320 let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
322 }
323
324 fn blinded_storage_node(
326 self,
327 account: B256,
328 path: Nibbles,
329 result_sender: Sender<BlindedNodeResult>,
330 tx_sender: Sender<ProofTaskMessage<Tx>>,
331 ) {
332 debug!(
333 target: "trie::proof_task",
334 ?account,
335 ?path,
336 "Starting blinded storage node retrieval"
337 );
338
339 let (trie_cursor_factory, hashed_cursor_factory) = self.create_factories();
340
341 let blinded_provider_factory = ProofBlindedProviderFactory::new(
342 trie_cursor_factory,
343 hashed_cursor_factory,
344 self.task_ctx.prefix_sets.clone(),
345 );
346
347 let start = Instant::now();
348 let result = blinded_provider_factory.storage_node_provider(account).blinded_node(&path);
349 debug!(
350 target: "trie::proof_task",
351 ?account,
352 ?path,
353 elapsed = ?start.elapsed(),
354 "Completed blinded storage node retrieval"
355 );
356
357 if let Err(error) = result_sender.send(result) {
358 tracing::error!(
359 target: "trie::proof_task",
360 ?account,
361 ?path,
362 ?error,
363 "Failed to send blinded storage node result"
364 );
365 }
366
367 let _ = tx_sender.send(ProofTaskMessage::Transaction(self));
369 }
370}
371
/// Parameters of a single storage proof task.
#[derive(Debug)]
pub struct StorageProofInput {
    /// Hashed address of the account whose storage proof is computed.
    hashed_address: B256,
    /// Prefix set passed to the storage proof computation.
    prefix_set: PrefixSet,
    /// Storage slots the multiproof is requested for.
    target_slots: B256Set,
    /// Flag forwarded to the proof builder's `with_branch_node_masks` setting.
    with_branch_node_masks: bool,
}
384
impl StorageProofInput {
    /// Creates the input for a storage proof task from the hashed account address, the prefix
    /// set, the target slots and the branch-node-mask flag.
    pub const fn new(
        hashed_address: B256,
        prefix_set: PrefixSet,
        target_slots: B256Set,
        with_branch_node_masks: bool,
    ) -> Self {
        Self { hashed_address, prefix_set, target_slots, with_branch_node_masks }
    }
}
397
/// Data shared by all proof tasks; cloning only clones the contained [`Arc`]s.
#[derive(Debug, Clone)]
pub struct ProofTaskCtx {
    /// Sorted in-memory trie updates layered over the database trie cursors.
    nodes_sorted: Arc<TrieUpdatesSorted>,
    /// Sorted in-memory hashed state layered over the database hashed cursors.
    state_sorted: Arc<HashedPostStateSorted>,
    /// Prefix sets handed to the blinded node provider factory.
    prefix_sets: Arc<TriePrefixSetsMut>,
}
411
impl ProofTaskCtx {
    /// Creates a task context from the sorted trie updates, the sorted hashed state and the
    /// prefix sets.
    pub const fn new(
        nodes_sorted: Arc<TrieUpdatesSorted>,
        state_sorted: Arc<HashedPostStateSorted>,
        prefix_sets: Arc<TriePrefixSetsMut>,
    ) -> Self {
        Self { nodes_sorted, state_sorted, prefix_sets }
    }
}
422
/// Messages processed by the [`ProofTaskManager`] loop.
#[derive(Debug)]
pub enum ProofTaskMessage<Tx> {
    /// Request to queue a new proof task.
    QueueTask(ProofTaskKind),
    /// A transaction being returned by a finished task so it can be reused.
    Transaction(ProofTaskTx<Tx>),
    /// Request to stop the manager loop.
    Terminate,
}
433
/// The kinds of work a proof task can perform.
///
/// Every variant carries the [`Sender`] on which the task reports its result.
#[derive(Debug)]
pub enum ProofTaskKind {
    /// Compute a storage multiproof for the given input.
    StorageProof(StorageProofInput, Sender<StorageProofResult>),
    /// Retrieve the blinded account trie node at the given path.
    BlindedAccountNode(Nibbles, Sender<BlindedNodeResult>),
    /// Retrieve the blinded storage trie node for the given account at the given path.
    BlindedStorageNode(B256, Nibbles, Sender<BlindedNodeResult>),
}
447
/// Handle for submitting work to a [`ProofTaskManager`].
///
/// Handles are counted: when the last one is dropped, a terminate message is sent to the
/// manager.
#[derive(Debug)]
pub struct ProofTaskManagerHandle<Tx> {
    /// Sender for messages to the manager.
    sender: Sender<ProofTaskMessage<Tx>>,
    /// Number of handles currently alive, shared between all of them.
    active_handles: Arc<AtomicUsize>,
}
457
458impl<Tx> ProofTaskManagerHandle<Tx> {
459 pub fn new(sender: Sender<ProofTaskMessage<Tx>>, active_handles: Arc<AtomicUsize>) -> Self {
461 active_handles.fetch_add(1, Ordering::SeqCst);
462 Self { sender, active_handles }
463 }
464
465 pub fn queue_task(&self, task: ProofTaskKind) -> Result<(), SendError<ProofTaskMessage<Tx>>> {
467 self.sender.send(ProofTaskMessage::QueueTask(task))
468 }
469
470 pub fn terminate(&self) {
472 let _ = self.sender.send(ProofTaskMessage::Terminate);
473 }
474}
475
476impl<Tx> Clone for ProofTaskManagerHandle<Tx> {
477 fn clone(&self) -> Self {
478 Self::new(self.sender.clone(), self.active_handles.clone())
479 }
480}
481
482impl<Tx> Drop for ProofTaskManagerHandle<Tx> {
483 fn drop(&mut self) {
484 if self.active_handles.fetch_sub(1, Ordering::SeqCst) == 1 {
487 self.terminate();
488 }
489 }
490}
491
492impl<Tx: DbTx> BlindedProviderFactory for ProofTaskManagerHandle<Tx> {
493 type AccountNodeProvider = ProofTaskBlindedNodeProvider<Tx>;
494 type StorageNodeProvider = ProofTaskBlindedNodeProvider<Tx>;
495
496 fn account_node_provider(&self) -> Self::AccountNodeProvider {
497 ProofTaskBlindedNodeProvider::AccountNode { sender: self.sender.clone() }
498 }
499
500 fn storage_node_provider(&self, account: B256) -> Self::StorageNodeProvider {
501 ProofTaskBlindedNodeProvider::StorageNode { account, sender: self.sender.clone() }
502 }
503}
504
/// Blinded node provider that delegates every lookup to the proof task manager.
#[derive(Debug)]
pub enum ProofTaskBlindedNodeProvider<Tx> {
    /// Provider for account trie nodes.
    AccountNode {
        /// Sender for messages to the manager.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
    /// Provider for the storage trie nodes of one account.
    StorageNode {
        /// The account whose storage trie nodes are requested.
        account: B256,
        /// Sender for messages to the manager.
        sender: Sender<ProofTaskMessage<Tx>>,
    },
}
521
522impl<Tx: DbTx> BlindedProvider for ProofTaskBlindedNodeProvider<Tx> {
523 fn blinded_node(&self, path: &Nibbles) -> Result<Option<RevealedNode>, SparseTrieError> {
524 let (tx, rx) = channel();
525 match self {
526 Self::AccountNode { sender } => {
527 let _ = sender.send(ProofTaskMessage::QueueTask(
528 ProofTaskKind::BlindedAccountNode(path.clone(), tx),
529 ));
530 }
531 Self::StorageNode { sender, account } => {
532 let _ = sender.send(ProofTaskMessage::QueueTask(
533 ProofTaskKind::BlindedStorageNode(*account, path.clone(), tx),
534 ));
535 }
536 }
537
538 rx.recv().unwrap()
539 }
540}