reth_engine_tree/tree/payload_processor/
multiproof.rs

1//! Multiproof task related functionality.
2
3use crate::tree::payload_processor::executor::WorkloadExecutor;
4use alloy_evm::block::StateChangeSource;
5use alloy_primitives::{
6    keccak256,
7    map::{B256Set, HashSet},
8    B256,
9};
10use derive_more::derive::Deref;
11use metrics::Histogram;
12use reth_errors::ProviderError;
13use reth_metrics::Metrics;
14use reth_provider::{providers::ConsistentDbView, BlockReader, DatabaseProviderFactory, FactoryTx};
15use reth_revm::state::EvmState;
16use reth_trie::{
17    added_removed_keys::MultiAddedRemovedKeys, prefix_set::TriePrefixSetsMut,
18    updates::TrieUpdatesSorted, DecodedMultiProof, HashedPostState, HashedPostStateSorted,
19    HashedStorage, MultiProofTargets, TrieInput,
20};
21use reth_trie_parallel::{proof::ParallelProof, proof_task::ProofTaskManagerHandle};
22use std::{
23    collections::{BTreeMap, VecDeque},
24    ops::DerefMut,
25    sync::{
26        mpsc::{channel, Receiver, Sender},
27        Arc,
28    },
29    time::{Duration, Instant},
30};
31use tracing::{debug, error, trace};
32
/// Chunk size used when splitting proof targets; each chunk is submitted as its own proof
/// calculation.
const MULTIPROOF_TARGETS_CHUNK_SIZE: usize = 10;
35
36/// A trie update that can be applied to sparse trie alongside the proofs for touched parts of the
37/// state.
38#[derive(Default, Debug)]
39pub struct SparseTrieUpdate {
40    /// The state update that was used to calculate the proof
41    pub(crate) state: HashedPostState,
42    /// The calculated multiproof
43    pub(crate) multiproof: DecodedMultiProof,
44}
45
46impl SparseTrieUpdate {
47    /// Returns true if the update is empty.
48    pub(super) fn is_empty(&self) -> bool {
49        self.state.is_empty() && self.multiproof.is_empty()
50    }
51
52    /// Construct update from multiproof.
53    #[cfg(test)]
54    pub(super) fn from_multiproof(multiproof: reth_trie::MultiProof) -> alloy_rlp::Result<Self> {
55        Ok(Self { multiproof: multiproof.try_into()?, ..Default::default() })
56    }
57
58    /// Extend update with contents of the other.
59    pub(super) fn extend(&mut self, other: Self) {
60        self.state.extend(other.state);
61        self.multiproof.extend(other.multiproof);
62    }
63}
64
65/// Common configuration for multi proof tasks
66#[derive(Debug, Clone)]
67pub(super) struct MultiProofConfig<Factory> {
68    /// View over the state in the database.
69    pub consistent_view: ConsistentDbView<Factory>,
70    /// The sorted collection of cached in-memory intermediate trie nodes that
71    /// can be reused for computation.
72    pub nodes_sorted: Arc<TrieUpdatesSorted>,
73    /// The sorted in-memory overlay hashed state.
74    pub state_sorted: Arc<HashedPostStateSorted>,
75    /// The collection of prefix sets for the computation. Since the prefix sets _always_
76    /// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
77    /// if we have cached nodes for them.
78    pub prefix_sets: Arc<TriePrefixSetsMut>,
79}
80
81impl<Factory> MultiProofConfig<Factory> {
82    /// Creates a new state root config from the consistent view and the trie input.
83    ///
84    /// This returns a cleared [`TrieInput`] so that we can reuse any allocated space in the
85    /// [`TrieInput`].
86    pub(super) fn new_from_input(
87        consistent_view: ConsistentDbView<Factory>,
88        mut input: TrieInput,
89    ) -> (TrieInput, Self) {
90        let config = Self {
91            consistent_view,
92            nodes_sorted: Arc::new(input.nodes.drain_into_sorted()),
93            state_sorted: Arc::new(input.state.drain_into_sorted()),
94            prefix_sets: Arc::new(input.prefix_sets.clone()),
95        };
96        (input.cleared(), config)
97    }
98}
99
100/// Messages used internally by the multi proof task.
101#[derive(Debug)]
102pub(super) enum MultiProofMessage {
103    /// Prefetch proof targets
104    PrefetchProofs(MultiProofTargets),
105    /// New state update from transaction execution with its source
106    StateUpdate(StateChangeSource, EvmState),
107    /// State update that can be applied to the sparse trie without any new proofs.
108    ///
109    /// It can be the case when all accounts and storage slots from the state update were already
110    /// fetched and revealed.
111    EmptyProof {
112        /// The index of this proof in the sequence of state updates
113        sequence_number: u64,
114        /// The state update that was used to calculate the proof
115        state: HashedPostState,
116    },
117    /// Proof calculation completed for a specific state update
118    ProofCalculated(Box<ProofCalculated>),
119    /// Error during proof calculation
120    ProofCalculationError(ProviderError),
121    /// Signals state update stream end.
122    ///
123    /// This is triggered by block execution, indicating that no additional state updates are
124    /// expected.
125    FinishedStateUpdates,
126}
127
128/// Message about completion of proof calculation for a specific state update
129#[derive(Debug)]
130pub(super) struct ProofCalculated {
131    /// The index of this proof in the sequence of state updates
132    sequence_number: u64,
133    /// Sparse trie update
134    update: SparseTrieUpdate,
135    /// The time taken to calculate the proof.
136    elapsed: Duration,
137}
138
139/// Handle to track proof calculation ordering.
140#[derive(Debug, Default)]
141struct ProofSequencer {
142    /// The next proof sequence number to be produced.
143    next_sequence: u64,
144    /// The next sequence number expected to be delivered.
145    next_to_deliver: u64,
146    /// Buffer for out-of-order proofs and corresponding state updates
147    pending_proofs: BTreeMap<u64, SparseTrieUpdate>,
148}
149
150impl ProofSequencer {
151    /// Gets the next sequence number and increments the counter
152    const fn next_sequence(&mut self) -> u64 {
153        let seq = self.next_sequence;
154        self.next_sequence += 1;
155        seq
156    }
157
158    /// Adds a proof with the corresponding state update and returns all sequential proofs and state
159    /// updates if we have a continuous sequence
160    fn add_proof(&mut self, sequence: u64, update: SparseTrieUpdate) -> Vec<SparseTrieUpdate> {
161        if sequence >= self.next_to_deliver {
162            self.pending_proofs.insert(sequence, update);
163        }
164
165        // return early if we don't have the next expected proof
166        if !self.pending_proofs.contains_key(&self.next_to_deliver) {
167            return Vec::new()
168        }
169
170        let mut consecutive_proofs = Vec::with_capacity(self.pending_proofs.len());
171        let mut current_sequence = self.next_to_deliver;
172
173        // keep collecting proofs and state updates as long as we have consecutive sequence numbers
174        while let Some(pending) = self.pending_proofs.remove(&current_sequence) {
175            consecutive_proofs.push(pending);
176            current_sequence += 1;
177
178            // if we don't have the next number, stop collecting
179            if !self.pending_proofs.contains_key(&current_sequence) {
180                break;
181            }
182        }
183
184        self.next_to_deliver += consecutive_proofs.len() as u64;
185
186        consecutive_proofs
187    }
188
189    /// Returns true if we still have pending proofs
190    pub(crate) fn has_pending(&self) -> bool {
191        !self.pending_proofs.is_empty()
192    }
193}
194
195/// A wrapper for the sender that signals completion when dropped.
196///
197/// This type is intended to be used in combination with the evm executor statehook.
198/// This should trigger once the block has been executed (after) the last state update has been
199/// sent. This triggers the exit condition of the multi proof task.
200#[derive(Deref, Debug)]
201pub(super) struct StateHookSender(Sender<MultiProofMessage>);
202
203impl StateHookSender {
204    pub(crate) const fn new(inner: Sender<MultiProofMessage>) -> Self {
205        Self(inner)
206    }
207}
208
209impl Drop for StateHookSender {
210    fn drop(&mut self) {
211        // Send completion signal when the sender is dropped
212        let _ = self.0.send(MultiProofMessage::FinishedStateUpdates);
213    }
214}
215
216pub(crate) fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
217    let mut hashed_state = HashedPostState::with_capacity(update.len());
218
219    for (address, account) in update {
220        if account.is_touched() {
221            let hashed_address = keccak256(address);
222            trace!(target: "engine::root", ?address, ?hashed_address, "Adding account to state update");
223
224            let destroyed = account.is_selfdestructed();
225            let info = if destroyed { None } else { Some(account.info.into()) };
226            hashed_state.accounts.insert(hashed_address, info);
227
228            let mut changed_storage_iter = account
229                .storage
230                .into_iter()
231                .filter(|(_slot, value)| value.is_changed())
232                .map(|(slot, value)| (keccak256(B256::from(slot)), value.present_value))
233                .peekable();
234
235            if destroyed {
236                hashed_state.storages.insert(hashed_address, HashedStorage::new(true));
237            } else if changed_storage_iter.peek().is_some() {
238                hashed_state
239                    .storages
240                    .insert(hashed_address, HashedStorage::from_iter(false, changed_storage_iter));
241            }
242        }
243    }
244
245    hashed_state
246}
247
248/// A pending multiproof task, either [`StorageMultiproofInput`] or [`MultiproofInput`].
249#[derive(Debug)]
250enum PendingMultiproofTask<Factory> {
251    /// A storage multiproof task input.
252    Storage(StorageMultiproofInput<Factory>),
253    /// A regular multiproof task input.
254    Regular(MultiproofInput<Factory>),
255}
256
257impl<Factory> PendingMultiproofTask<Factory> {
258    /// Returns the proof sequence number of the task.
259    const fn proof_sequence_number(&self) -> u64 {
260        match self {
261            Self::Storage(input) => input.proof_sequence_number,
262            Self::Regular(input) => input.proof_sequence_number,
263        }
264    }
265
266    /// Returns whether or not the proof targets are empty.
267    fn proof_targets_is_empty(&self) -> bool {
268        match self {
269            Self::Storage(input) => input.proof_targets.is_empty(),
270            Self::Regular(input) => input.proof_targets.is_empty(),
271        }
272    }
273
274    /// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
275    fn send_empty_proof(self) {
276        match self {
277            Self::Storage(input) => input.send_empty_proof(),
278            Self::Regular(input) => input.send_empty_proof(),
279        }
280    }
281}
282
283impl<Factory> From<StorageMultiproofInput<Factory>> for PendingMultiproofTask<Factory> {
284    fn from(input: StorageMultiproofInput<Factory>) -> Self {
285        Self::Storage(input)
286    }
287}
288
289impl<Factory> From<MultiproofInput<Factory>> for PendingMultiproofTask<Factory> {
290    fn from(input: MultiproofInput<Factory>) -> Self {
291        Self::Regular(input)
292    }
293}
294
295/// Input parameters for spawning a dedicated storage multiproof calculation.
296#[derive(Debug)]
297struct StorageMultiproofInput<Factory> {
298    config: MultiProofConfig<Factory>,
299    source: Option<StateChangeSource>,
300    hashed_state_update: HashedPostState,
301    hashed_address: B256,
302    proof_targets: B256Set,
303    proof_sequence_number: u64,
304    state_root_message_sender: Sender<MultiProofMessage>,
305    multi_added_removed_keys: Arc<MultiAddedRemovedKeys>,
306}
307
308impl<Factory> StorageMultiproofInput<Factory> {
309    /// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
310    fn send_empty_proof(self) {
311        let _ = self.state_root_message_sender.send(MultiProofMessage::EmptyProof {
312            sequence_number: self.proof_sequence_number,
313            state: self.hashed_state_update,
314        });
315    }
316}
317
318/// Input parameters for spawning a multiproof calculation.
319#[derive(Debug)]
320struct MultiproofInput<Factory> {
321    config: MultiProofConfig<Factory>,
322    source: Option<StateChangeSource>,
323    hashed_state_update: HashedPostState,
324    proof_targets: MultiProofTargets,
325    proof_sequence_number: u64,
326    state_root_message_sender: Sender<MultiProofMessage>,
327    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
328}
329
330impl<Factory> MultiproofInput<Factory> {
331    /// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
332    fn send_empty_proof(self) {
333        let _ = self.state_root_message_sender.send(MultiProofMessage::EmptyProof {
334            sequence_number: self.proof_sequence_number,
335            state: self.hashed_state_update,
336        });
337    }
338}
339
340/// Manages concurrent multiproof calculations.
341/// Takes care of not having more calculations in flight than a given maximum
342/// concurrency, further calculation requests are queued and spawn later, after
343/// availability has been signaled.
344#[derive(Debug)]
345pub struct MultiproofManager<Factory: DatabaseProviderFactory> {
346    /// Maximum number of concurrent calculations.
347    max_concurrent: usize,
348    /// Currently running calculations.
349    inflight: usize,
350    /// Queued calculations.
351    pending: VecDeque<PendingMultiproofTask<Factory>>,
352    /// Executor for tasks
353    executor: WorkloadExecutor,
354    /// Sender to the storage proof task.
355    storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
356    /// Metrics
357    metrics: MultiProofTaskMetrics,
358}
359
360impl<Factory> MultiproofManager<Factory>
361where
362    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
363{
364    /// Creates a new [`MultiproofManager`].
365    fn new(
366        executor: WorkloadExecutor,
367        metrics: MultiProofTaskMetrics,
368        storage_proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
369        max_concurrent: usize,
370    ) -> Self {
371        Self {
372            pending: VecDeque::with_capacity(max_concurrent),
373            max_concurrent,
374            executor,
375            inflight: 0,
376            metrics,
377            storage_proof_task_handle,
378        }
379    }
380
381    /// Spawns a new multiproof calculation or enqueues it for later if
382    /// `max_concurrent` are already inflight.
383    fn spawn_or_queue(&mut self, input: PendingMultiproofTask<Factory>) {
384        // If there are no proof targets, we can just send an empty multiproof back immediately
385        if input.proof_targets_is_empty() {
386            debug!(
387                sequence_number = input.proof_sequence_number(),
388                "No proof targets, sending empty multiproof back immediately"
389            );
390            input.send_empty_proof();
391            return
392        }
393
394        if self.inflight >= self.max_concurrent {
395            self.pending.push_back(input);
396            self.metrics.pending_multiproofs_histogram.record(self.pending.len() as f64);
397            return;
398        }
399
400        self.spawn_multiproof_task(input);
401    }
402
403    /// Signals that a multiproof calculation has finished and there's room to
404    /// spawn a new calculation if needed.
405    fn on_calculation_complete(&mut self) {
406        self.inflight = self.inflight.saturating_sub(1);
407        self.metrics.inflight_multiproofs_histogram.record(self.inflight as f64);
408
409        if let Some(input) = self.pending.pop_front() {
410            self.metrics.pending_multiproofs_histogram.record(self.pending.len() as f64);
411            self.spawn_multiproof_task(input);
412        }
413    }
414
415    /// Spawns a multiproof task, dispatching to `spawn_storage_proof` if the input is a storage
416    /// multiproof, and dispatching to `spawn_multiproof` otherwise.
417    fn spawn_multiproof_task(&mut self, input: PendingMultiproofTask<Factory>) {
418        match input {
419            PendingMultiproofTask::Storage(storage_input) => {
420                self.spawn_storage_proof(storage_input);
421            }
422            PendingMultiproofTask::Regular(multiproof_input) => {
423                self.spawn_multiproof(multiproof_input);
424            }
425        }
426    }
427
428    /// Spawns a single storage proof calculation task.
429    fn spawn_storage_proof(&mut self, storage_multiproof_input: StorageMultiproofInput<Factory>) {
430        let StorageMultiproofInput {
431            config,
432            source,
433            hashed_state_update,
434            hashed_address,
435            proof_targets,
436            proof_sequence_number,
437            state_root_message_sender,
438            multi_added_removed_keys,
439        } = storage_multiproof_input;
440
441        let storage_proof_task_handle = self.storage_proof_task_handle.clone();
442
443        self.executor.spawn_blocking(move || {
444            let storage_targets = proof_targets.len();
445
446            trace!(
447                target: "engine::root",
448                proof_sequence_number,
449                ?proof_targets,
450                storage_targets,
451                "Starting dedicated storage proof calculation",
452            );
453            let start = Instant::now();
454            let result = ParallelProof::new(
455                config.consistent_view,
456                config.nodes_sorted,
457                config.state_sorted,
458                config.prefix_sets,
459                storage_proof_task_handle.clone(),
460            )
461            .with_branch_node_masks(true)
462            .with_multi_added_removed_keys(Some(multi_added_removed_keys))
463            .decoded_storage_proof(hashed_address, proof_targets);
464            let elapsed = start.elapsed();
465            trace!(
466                target: "engine::root",
467                proof_sequence_number,
468                ?elapsed,
469                ?source,
470                storage_targets,
471                "Storage multiproofs calculated",
472            );
473
474            match result {
475                Ok(proof) => {
476                    let _ = state_root_message_sender.send(MultiProofMessage::ProofCalculated(
477                        Box::new(ProofCalculated {
478                            sequence_number: proof_sequence_number,
479                            update: SparseTrieUpdate {
480                                state: hashed_state_update,
481                                multiproof: DecodedMultiProof::from_storage_proof(
482                                    hashed_address,
483                                    proof,
484                                ),
485                            },
486                            elapsed,
487                        }),
488                    ));
489                }
490                Err(error) => {
491                    let _ = state_root_message_sender
492                        .send(MultiProofMessage::ProofCalculationError(error.into()));
493                }
494            }
495        });
496
497        self.inflight += 1;
498        self.metrics.inflight_multiproofs_histogram.record(self.inflight as f64);
499    }
500
501    /// Spawns a single multiproof calculation task.
502    fn spawn_multiproof(&mut self, multiproof_input: MultiproofInput<Factory>) {
503        let MultiproofInput {
504            config,
505            source,
506            hashed_state_update,
507            proof_targets,
508            proof_sequence_number,
509            state_root_message_sender,
510            multi_added_removed_keys,
511        } = multiproof_input;
512        let storage_proof_task_handle = self.storage_proof_task_handle.clone();
513
514        self.executor.spawn_blocking(move || {
515            let account_targets = proof_targets.len();
516            let storage_targets = proof_targets.values().map(|slots| slots.len()).sum::<usize>();
517
518            trace!(
519                target: "engine::root",
520                proof_sequence_number,
521                ?proof_targets,
522                account_targets,
523                storage_targets,
524                ?source,
525                "Starting multiproof calculation",
526            );
527
528            let start = Instant::now();
529            let result = ParallelProof::new(
530                config.consistent_view,
531                config.nodes_sorted,
532                config.state_sorted,
533                config.prefix_sets,
534                storage_proof_task_handle.clone(),
535            )
536            .with_branch_node_masks(true)
537            .with_multi_added_removed_keys(multi_added_removed_keys)
538            .decoded_multiproof(proof_targets);
539            let elapsed = start.elapsed();
540            trace!(
541                target: "engine::root",
542                proof_sequence_number,
543                ?elapsed,
544                ?source,
545                account_targets,
546                storage_targets,
547                "Multiproof calculated",
548            );
549
550            match result {
551                Ok(proof) => {
552                    let _ = state_root_message_sender.send(MultiProofMessage::ProofCalculated(
553                        Box::new(ProofCalculated {
554                            sequence_number: proof_sequence_number,
555                            update: SparseTrieUpdate {
556                                state: hashed_state_update,
557                                multiproof: proof,
558                            },
559                            elapsed,
560                        }),
561                    ));
562                }
563                Err(error) => {
564                    let _ = state_root_message_sender
565                        .send(MultiProofMessage::ProofCalculationError(error.into()));
566                }
567            }
568        });
569
570        self.inflight += 1;
571        self.metrics.inflight_multiproofs_histogram.record(self.inflight as f64);
572    }
573}
574
575#[derive(Metrics, Clone)]
576#[metrics(scope = "tree.root")]
577pub(crate) struct MultiProofTaskMetrics {
578    /// Histogram of inflight multiproofs.
579    pub inflight_multiproofs_histogram: Histogram,
580    /// Histogram of pending multiproofs.
581    pub pending_multiproofs_histogram: Histogram,
582
583    /// Histogram of the number of prefetch proof target accounts.
584    pub prefetch_proof_targets_accounts_histogram: Histogram,
585    /// Histogram of the number of prefetch proof target storages.
586    pub prefetch_proof_targets_storages_histogram: Histogram,
587    /// Histogram of the number of prefetch proof target chunks.
588    pub prefetch_proof_chunks_histogram: Histogram,
589
590    /// Histogram of the number of state update proof target accounts.
591    pub state_update_proof_targets_accounts_histogram: Histogram,
592    /// Histogram of the number of state update proof target storages.
593    pub state_update_proof_targets_storages_histogram: Histogram,
594    /// Histogram of the number of state update proof target chunks.
595    pub state_update_proof_chunks_histogram: Histogram,
596
597    /// Histogram of proof calculation durations.
598    pub proof_calculation_duration_histogram: Histogram,
599
600    /// Histogram of sparse trie update durations.
601    pub sparse_trie_update_duration_histogram: Histogram,
602    /// Histogram of sparse trie final update durations.
603    pub sparse_trie_final_update_duration_histogram: Histogram,
604    /// Histogram of sparse trie total durations.
605    pub sparse_trie_total_duration_histogram: Histogram,
606
607    /// Histogram of state updates received.
608    pub state_updates_received_histogram: Histogram,
609    /// Histogram of proofs processed.
610    pub proofs_processed_histogram: Histogram,
611    /// Histogram of total time spent in the multiproof task.
612    pub multiproof_task_total_duration_histogram: Histogram,
613    /// Total time spent waiting for the first state update or prefetch request.
614    pub first_update_wait_time_histogram: Histogram,
615    /// Total time spent waiting for the last proof result.
616    pub last_proof_wait_time_histogram: Histogram,
617}
618
619/// Standalone task that receives a transaction state stream and updates relevant
620/// data structures to calculate state root.
621///
622/// It is responsible of  initializing a blinded sparse trie and subscribe to
623/// transaction state stream. As it receives transaction execution results, it
624/// fetches the proofs for relevant accounts from the database and reveal them
625/// to the tree.
626/// Then it updates relevant leaves according to the result of the transaction.
627/// This feeds updates to the sparse trie task.
628#[derive(Debug)]
629pub(super) struct MultiProofTask<Factory: DatabaseProviderFactory> {
630    /// Task configuration.
631    config: MultiProofConfig<Factory>,
632    /// Receiver for state root related messages.
633    rx: Receiver<MultiProofMessage>,
634    /// Sender for state root related messages.
635    tx: Sender<MultiProofMessage>,
636    /// Sender for state updates emitted by this type.
637    to_sparse_trie: Sender<SparseTrieUpdate>,
638    /// Proof targets that have been already fetched.
639    fetched_proof_targets: MultiProofTargets,
640    /// Tracks keys which have been added and removed throughout the entire block.
641    multi_added_removed_keys: MultiAddedRemovedKeys,
642    /// Proof sequencing handler.
643    proof_sequencer: ProofSequencer,
644    /// Manages calculation of multiproofs.
645    multiproof_manager: MultiproofManager<Factory>,
646    /// multi proof task metrics
647    metrics: MultiProofTaskMetrics,
648}
649
650impl<Factory> MultiProofTask<Factory>
651where
652    Factory: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
653{
654    /// Creates a new multi proof task with the unified message channel
655    pub(super) fn new(
656        config: MultiProofConfig<Factory>,
657        executor: WorkloadExecutor,
658        proof_task_handle: ProofTaskManagerHandle<FactoryTx<Factory>>,
659        to_sparse_trie: Sender<SparseTrieUpdate>,
660        max_concurrency: usize,
661    ) -> Self {
662        let (tx, rx) = channel();
663        let metrics = MultiProofTaskMetrics::default();
664
665        Self {
666            config,
667            rx,
668            tx,
669            to_sparse_trie,
670            fetched_proof_targets: Default::default(),
671            multi_added_removed_keys: MultiAddedRemovedKeys::new(),
672            proof_sequencer: ProofSequencer::default(),
673            multiproof_manager: MultiproofManager::new(
674                executor,
675                metrics.clone(),
676                proof_task_handle,
677                max_concurrency,
678            ),
679            metrics,
680        }
681    }
682
683    /// Returns a [`Sender`] that can be used to send arbitrary [`MultiProofMessage`]s to this task.
684    pub(super) fn state_root_message_sender(&self) -> Sender<MultiProofMessage> {
685        self.tx.clone()
686    }
687
688    /// Handles request for proof prefetch.
689    ///
690    /// Returns a number of proofs that were spawned.
691    fn on_prefetch_proof(&mut self, targets: MultiProofTargets) -> u64 {
692        let proof_targets = self.get_prefetch_proof_targets(targets);
693        self.fetched_proof_targets.extend_ref(&proof_targets);
694
695        // Make sure all target accounts have an `AddedRemovedKeySet` in the
696        // [`MultiAddedRemovedKeys`]. Even if there are not any known removed keys for the account,
697        // we still want to optimistically fetch extension children for the leaf addition case.
698        self.multi_added_removed_keys.touch_accounts(proof_targets.keys().copied());
699
700        // Clone+Arc MultiAddedRemovedKeys for sharing with the spawned multiproof tasks
701        let multi_added_removed_keys = Arc::new(self.multi_added_removed_keys.clone());
702
703        self.metrics.prefetch_proof_targets_accounts_histogram.record(proof_targets.len() as f64);
704        self.metrics
705            .prefetch_proof_targets_storages_histogram
706            .record(proof_targets.values().map(|slots| slots.len()).sum::<usize>() as f64);
707
708        // Process proof targets in chunks.
709        let mut chunks = 0;
710        for proof_targets_chunk in proof_targets.chunks(MULTIPROOF_TARGETS_CHUNK_SIZE) {
711            self.multiproof_manager.spawn_or_queue(
712                MultiproofInput {
713                    config: self.config.clone(),
714                    source: None,
715                    hashed_state_update: Default::default(),
716                    proof_targets: proof_targets_chunk,
717                    proof_sequence_number: self.proof_sequencer.next_sequence(),
718                    state_root_message_sender: self.tx.clone(),
719                    multi_added_removed_keys: Some(multi_added_removed_keys.clone()),
720                }
721                .into(),
722            );
723            chunks += 1;
724        }
725        self.metrics.prefetch_proof_chunks_histogram.record(chunks as f64);
726
727        chunks
728    }
729
730    // Returns true if all state updates finished and all proofs processed.
731    fn is_done(
732        &self,
733        proofs_processed: u64,
734        state_update_proofs_requested: u64,
735        prefetch_proofs_requested: u64,
736        updates_finished: bool,
737    ) -> bool {
738        let all_proofs_processed =
739            proofs_processed >= state_update_proofs_requested + prefetch_proofs_requested;
740        let no_pending = !self.proof_sequencer.has_pending();
741        debug!(
742            target: "engine::root",
743            proofs_processed,
744            state_update_proofs_requested,
745            prefetch_proofs_requested,
746            no_pending,
747            updates_finished,
748            "Checking end condition"
749        );
750        all_proofs_processed && no_pending && updates_finished
751    }
752
753    /// Calls `get_proof_targets` with existing proof targets for prefetching.
754    fn get_prefetch_proof_targets(&self, mut targets: MultiProofTargets) -> MultiProofTargets {
755        // Here we want to filter out any targets that are already fetched
756        //
757        // This means we need to remove any storage slots that have already been fetched
758        let mut duplicates = 0;
759
760        // First remove all storage targets that are subsets of already fetched storage slots
761        targets.retain(|hashed_address, target_storage| {
762            let keep = self
763                .fetched_proof_targets
764                .get(hashed_address)
765                // do NOT remove if None, because that means the account has not been fetched yet
766                .is_none_or(|fetched_storage| {
767                    // remove if a subset
768                    !target_storage.is_subset(fetched_storage)
769                });
770
771            if !keep {
772                duplicates += target_storage.len();
773            }
774
775            keep
776        });
777
778        // For all non-subset remaining targets, we have to calculate the difference
779        for (hashed_address, target_storage) in targets.deref_mut() {
780            let Some(fetched_storage) = self.fetched_proof_targets.get(hashed_address) else {
781                // this means the account has not been fetched yet, so we must fetch everything
782                // associated with this account
783                continue
784            };
785
786            let prev_target_storage_len = target_storage.len();
787
788            // keep only the storage slots that have not been fetched yet
789            //
790            // we already removed subsets, so this should only remove duplicates
791            target_storage.retain(|slot| !fetched_storage.contains(slot));
792
793            duplicates += prev_target_storage_len - target_storage.len();
794        }
795
796        if duplicates > 0 {
797            trace!(target: "engine::root", duplicates, "Removed duplicate prefetch proof targets");
798        }
799
800        targets
801    }
802
803    /// Handles state updates.
804    ///
805    /// Returns a number of proofs that were spawned.
806    fn on_state_update(&mut self, source: StateChangeSource, update: EvmState) -> u64 {
807        let hashed_state_update = evm_state_to_hashed_post_state(update);
808
809        // Update removed keys based on the state update.
810        self.multi_added_removed_keys.update_with_state(&hashed_state_update);
811
812        // Split the state update into already fetched and not fetched according to the proof
813        // targets.
814        let (fetched_state_update, not_fetched_state_update) = hashed_state_update
815            .partition_by_targets(&self.fetched_proof_targets, &self.multi_added_removed_keys);
816
817        let mut state_updates = 0;
818        // If there are any accounts or storage slots that we already fetched the proofs for,
819        // send them immediately, as they don't require spawning any additional multiproofs.
820        if !fetched_state_update.is_empty() {
821            let _ = self.tx.send(MultiProofMessage::EmptyProof {
822                sequence_number: self.proof_sequencer.next_sequence(),
823                state: fetched_state_update,
824            });
825            state_updates += 1;
826        }
827
828        // Clone+Arc MultiAddedRemovedKeys for sharing with the spawned multiproof tasks
829        let multi_added_removed_keys = Arc::new(self.multi_added_removed_keys.clone());
830
831        // Process state updates in chunks.
832        let mut chunks = 0;
833        let mut spawned_proof_targets = MultiProofTargets::default();
834        for chunk in not_fetched_state_update.chunks(MULTIPROOF_TARGETS_CHUNK_SIZE) {
835            let proof_targets =
836                get_proof_targets(&chunk, &self.fetched_proof_targets, &multi_added_removed_keys);
837            spawned_proof_targets.extend_ref(&proof_targets);
838
839            self.multiproof_manager.spawn_or_queue(
840                MultiproofInput {
841                    config: self.config.clone(),
842                    source: Some(source),
843                    hashed_state_update: chunk,
844                    proof_targets,
845                    proof_sequence_number: self.proof_sequencer.next_sequence(),
846                    state_root_message_sender: self.tx.clone(),
847                    multi_added_removed_keys: Some(multi_added_removed_keys.clone()),
848                }
849                .into(),
850            );
851            chunks += 1;
852        }
853
854        self.metrics
855            .state_update_proof_targets_accounts_histogram
856            .record(spawned_proof_targets.len() as f64);
857        self.metrics
858            .state_update_proof_targets_storages_histogram
859            .record(spawned_proof_targets.values().map(|slots| slots.len()).sum::<usize>() as f64);
860        self.metrics.state_update_proof_chunks_histogram.record(chunks as f64);
861
862        self.fetched_proof_targets.extend(spawned_proof_targets);
863
864        state_updates + chunks
865    }
866
867    /// Handler for new proof calculated, aggregates all the existing sequential proofs.
868    fn on_proof(
869        &mut self,
870        sequence_number: u64,
871        update: SparseTrieUpdate,
872    ) -> Option<SparseTrieUpdate> {
873        let ready_proofs = self.proof_sequencer.add_proof(sequence_number, update);
874
875        ready_proofs
876            .into_iter()
877            // Merge all ready proofs and state updates
878            .reduce(|mut acc_update, update| {
879                acc_update.extend(update);
880                acc_update
881            })
882            // Return None if the resulting proof is empty
883            .filter(|proof| !proof.is_empty())
884    }
885
886    /// Starts the main loop that handles all incoming messages, fetches proofs, applies them to the
887    /// sparse trie, updates the sparse trie, and eventually returns the state root.
888    ///
889    /// The lifecycle is the following:
890    /// 1. Either [`MultiProofMessage::PrefetchProofs`] or [`MultiProofMessage::StateUpdate`] is
891    ///    received from the engine.
892    ///    * For [`MultiProofMessage::StateUpdate`], the state update is hashed with
893    ///      [`evm_state_to_hashed_post_state`], and then (proof targets)[`MultiProofTargets`] are
894    ///      extracted with [`get_proof_targets`].
895    ///    * For both messages, proof targets are deduplicated according to `fetched_proof_targets`,
896    ///      so that the proofs for accounts and storage slots that were already fetched are not
897    ///      requested again.
898    /// 2. Using the proof targets, a new multiproof is calculated using
899    ///    [`MultiproofManager::spawn_or_queue`].
900    ///    * If the list of proof targets is empty, the [`MultiProofMessage::EmptyProof`] message is
901    ///      sent back to this task along with the original state update.
902    ///    * Otherwise, the multiproof is calculated and the [`MultiProofMessage::ProofCalculated`]
903    ///      message is sent back to this task along with the resulting multiproof, proof targets
904    ///      and original state update.
905    /// 3. Either [`MultiProofMessage::EmptyProof`] or [`MultiProofMessage::ProofCalculated`] is
906    ///    received.
907    ///    * The multiproof is added to the (proof sequencer)[`ProofSequencer`].
908    ///    * If the proof sequencer has a contiguous sequence of multiproofs in the same order as
909    ///      state updates arrived (i.e. transaction order), such sequence is returned.
910    /// 4. Once there's a sequence of contiguous multiproofs along with the proof targets and state
911    ///    updates associated with them, a [`SparseTrieUpdate`] is generated and sent to the sparse
912    ///    trie task.
913    /// 5. Steps above are repeated until this task receives a
914    ///    [`MultiProofMessage::FinishedStateUpdates`].
915    ///    * Once this message is received, on every [`MultiProofMessage::EmptyProof`] and
916    ///      [`MultiProofMessage::ProofCalculated`] message, we check if there are any proofs are
917    ///      currently being calculated, or if there are any pending proofs in the proof sequencer
918    ///      left to be revealed by checking the pending tasks.
919    /// 6. This task exits after all pending proofs are processed.
920    pub(crate) fn run(mut self) {
921        // TODO convert those into fields
922        let mut prefetch_proofs_requested = 0;
923        let mut state_update_proofs_requested = 0;
924        let mut proofs_processed = 0;
925
926        let mut updates_finished = false;
927
928        // Timestamp before the first state update or prefetch was received
929        let start = Instant::now();
930
931        // Timestamp when the first state update or prefetch was received
932        let mut first_update_time = None;
933        // Timestamp when state updates have finished
934        let mut updates_finished_time = None;
935
936        loop {
937            trace!(target: "engine::root", "entering main channel receiving loop");
938            match self.rx.recv() {
939                Ok(message) => match message {
940                    MultiProofMessage::PrefetchProofs(targets) => {
941                        trace!(target: "engine::root", "processing MultiProofMessage::PrefetchProofs");
942                        if first_update_time.is_none() {
943                            // record the wait time
944                            self.metrics
945                                .first_update_wait_time_histogram
946                                .record(start.elapsed().as_secs_f64());
947                            first_update_time = Some(Instant::now());
948                            debug!(target: "engine::root", "Started state root calculation");
949                        }
950
951                        let account_targets = targets.len();
952                        let storage_targets =
953                            targets.values().map(|slots| slots.len()).sum::<usize>();
954                        prefetch_proofs_requested += self.on_prefetch_proof(targets);
955                        debug!(
956                            target: "engine::root",
957                            account_targets,
958                            storage_targets,
959                            prefetch_proofs_requested,
960                            "Prefetching proofs"
961                        );
962                    }
963                    MultiProofMessage::StateUpdate(source, update) => {
964                        trace!(target: "engine::root", "processing MultiProofMessage::StateUpdate");
965                        if first_update_time.is_none() {
966                            // record the wait time
967                            self.metrics
968                                .first_update_wait_time_histogram
969                                .record(start.elapsed().as_secs_f64());
970                            first_update_time = Some(Instant::now());
971                            debug!(target: "engine::root", "Started state root calculation");
972                        }
973
974                        let len = update.len();
975                        state_update_proofs_requested += self.on_state_update(source, update);
976                        debug!(
977                            target: "engine::root",
978                            ?source,
979                            len,
980                            ?state_update_proofs_requested,
981                            "Received new state update"
982                        );
983                    }
984                    MultiProofMessage::FinishedStateUpdates => {
985                        trace!(target: "engine::root", "processing MultiProofMessage::FinishedStateUpdates");
986                        updates_finished = true;
987                        updates_finished_time = Some(Instant::now());
988                        if self.is_done(
989                            proofs_processed,
990                            state_update_proofs_requested,
991                            prefetch_proofs_requested,
992                            updates_finished,
993                        ) {
994                            debug!(
995                                target: "engine::root",
996                                "State updates finished and all proofs processed, ending calculation"
997                            );
998                            break
999                        }
1000                    }
1001                    MultiProofMessage::EmptyProof { sequence_number, state } => {
1002                        trace!(target: "engine::root", "processing MultiProofMessage::EmptyProof");
1003
1004                        proofs_processed += 1;
1005
1006                        if let Some(combined_update) = self.on_proof(
1007                            sequence_number,
1008                            SparseTrieUpdate { state, multiproof: Default::default() },
1009                        ) {
1010                            let _ = self.to_sparse_trie.send(combined_update);
1011                        }
1012
1013                        if self.is_done(
1014                            proofs_processed,
1015                            state_update_proofs_requested,
1016                            prefetch_proofs_requested,
1017                            updates_finished,
1018                        ) {
1019                            debug!(
1020                                target: "engine::root",
1021                                "State updates finished and all proofs processed, ending calculation"
1022                            );
1023                            break
1024                        }
1025                    }
1026                    MultiProofMessage::ProofCalculated(proof_calculated) => {
1027                        trace!(target: "engine::root", "processing
1028        MultiProofMessage::ProofCalculated");
1029
1030                        // we increment proofs_processed for both state updates and prefetches,
1031                        // because both are used for the root termination condition.
1032                        proofs_processed += 1;
1033
1034                        self.metrics
1035                            .proof_calculation_duration_histogram
1036                            .record(proof_calculated.elapsed);
1037
1038                        debug!(
1039                            target: "engine::root",
1040                            sequence = proof_calculated.sequence_number,
1041                            total_proofs = proofs_processed,
1042                            "Processing calculated proof"
1043                        );
1044
1045                        self.multiproof_manager.on_calculation_complete();
1046
1047                        if let Some(combined_update) =
1048                            self.on_proof(proof_calculated.sequence_number, proof_calculated.update)
1049                        {
1050                            let _ = self.to_sparse_trie.send(combined_update);
1051                        }
1052
1053                        if self.is_done(
1054                            proofs_processed,
1055                            state_update_proofs_requested,
1056                            prefetch_proofs_requested,
1057                            updates_finished,
1058                        ) {
1059                            debug!(
1060                                target: "engine::root",
1061                                "State updates finished and all proofs processed, ending calculation");
1062                            break
1063                        }
1064                    }
1065                    MultiProofMessage::ProofCalculationError(err) => {
1066                        error!(
1067                            target: "engine::root",
1068                            ?err,
1069                            "proof calculation error"
1070                        );
1071                        return
1072                    }
1073                },
1074                Err(_) => {
1075                    // this means our internal message channel is closed, which shouldn't happen
1076                    // in normal operation since we hold both ends
1077                    error!(
1078                        target: "engine::root",
1079                        "Internal message channel closed unexpectedly"
1080                    );
1081                }
1082            }
1083        }
1084
1085        debug!(
1086            target: "engine::root",
1087            total_updates = state_update_proofs_requested,
1088            total_proofs = proofs_processed,
1089            total_time = ?first_update_time.map(|t|t.elapsed()),
1090            time_since_updates_finished = ?updates_finished_time.map(|t|t.elapsed()),
1091            "All proofs processed, ending calculation"
1092        );
1093
1094        // update total metrics on finish
1095        self.metrics.state_updates_received_histogram.record(state_update_proofs_requested as f64);
1096        self.metrics.proofs_processed_histogram.record(proofs_processed as f64);
1097        if let Some(total_time) = first_update_time.map(|t| t.elapsed()) {
1098            self.metrics.multiproof_task_total_duration_histogram.record(total_time);
1099        }
1100
1101        if let Some(updates_finished_time) = updates_finished_time {
1102            self.metrics
1103                .last_proof_wait_time_histogram
1104                .record(updates_finished_time.elapsed().as_secs_f64());
1105        }
1106    }
1107}
1108
1109/// Returns accounts only with those storages that were not already fetched, and
1110/// if there are no such storages and the account itself was already fetched, the
1111/// account shouldn't be included.
1112fn get_proof_targets(
1113    state_update: &HashedPostState,
1114    fetched_proof_targets: &MultiProofTargets,
1115    multi_added_removed_keys: &MultiAddedRemovedKeys,
1116) -> MultiProofTargets {
1117    let mut targets = MultiProofTargets::default();
1118
1119    // first collect all new accounts (not previously fetched)
1120    for &hashed_address in state_update.accounts.keys() {
1121        if !fetched_proof_targets.contains_key(&hashed_address) {
1122            targets.insert(hashed_address, HashSet::default());
1123        }
1124    }
1125
1126    // then process storage slots for all accounts in the state update
1127    for (hashed_address, storage) in &state_update.storages {
1128        let fetched = fetched_proof_targets.get(hashed_address);
1129        let storage_added_removed_keys = multi_added_removed_keys.get_storage(hashed_address);
1130        let mut changed_slots = storage
1131            .storage
1132            .keys()
1133            .filter(|slot| {
1134                !fetched.is_some_and(|f| f.contains(*slot)) ||
1135                    storage_added_removed_keys.is_some_and(|k| k.is_removed(slot))
1136            })
1137            .peekable();
1138
1139        // If the storage is wiped, we still need to fetch the account proof.
1140        if storage.wiped && fetched.is_none() {
1141            targets.entry(*hashed_address).or_default();
1142        }
1143
1144        if changed_slots.peek().is_some() {
1145            targets.entry(*hashed_address).or_default().extend(changed_slots);
1146        }
1147    }
1148
1149    targets
1150}
1151
1152#[cfg(test)]
1153mod tests {
1154    use super::*;
1155    use alloy_primitives::map::B256Set;
1156    use reth_provider::{providers::ConsistentDbView, test_utils::create_test_provider_factory};
1157    use reth_trie::{MultiProof, TrieInput};
1158    use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofTaskManager};
1159    use revm_primitives::{B256, U256};
1160    use std::sync::Arc;
1161
1162    fn create_state_root_config<F>(factory: F, input: TrieInput) -> MultiProofConfig<F>
1163    where
1164        F: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
1165    {
1166        let consistent_view = ConsistentDbView::new(factory, None);
1167        let nodes_sorted = Arc::new(input.nodes.clone().into_sorted());
1168        let state_sorted = Arc::new(input.state.clone().into_sorted());
1169        let prefix_sets = Arc::new(input.prefix_sets);
1170
1171        MultiProofConfig { consistent_view, nodes_sorted, state_sorted, prefix_sets }
1172    }
1173
1174    fn create_test_state_root_task<F>(factory: F) -> MultiProofTask<F>
1175    where
1176        F: DatabaseProviderFactory<Provider: BlockReader> + Clone + 'static,
1177    {
1178        let executor = WorkloadExecutor::default();
1179        let config = create_state_root_config(factory, TrieInput::default());
1180        let task_ctx = ProofTaskCtx::new(
1181            config.nodes_sorted.clone(),
1182            config.state_sorted.clone(),
1183            config.prefix_sets.clone(),
1184        );
1185        let proof_task = ProofTaskManager::new(
1186            executor.handle().clone(),
1187            config.consistent_view.clone(),
1188            task_ctx,
1189            1,
1190        );
1191        let channel = channel();
1192
1193        MultiProofTask::new(config, executor, proof_task.handle(), channel.0, 1)
1194    }
1195
1196    #[test]
1197    fn test_add_proof_in_sequence() {
1198        let mut sequencer = ProofSequencer::default();
1199        let proof1 = MultiProof::default();
1200        let proof2 = MultiProof::default();
1201        sequencer.next_sequence = 2;
1202
1203        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1204        assert_eq!(ready.len(), 1);
1205        assert!(!sequencer.has_pending());
1206
1207        let ready = sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proof2).unwrap());
1208        assert_eq!(ready.len(), 1);
1209        assert!(!sequencer.has_pending());
1210    }
1211
1212    #[test]
1213    fn test_add_proof_out_of_order() {
1214        let mut sequencer = ProofSequencer::default();
1215        let proof1 = MultiProof::default();
1216        let proof2 = MultiProof::default();
1217        let proof3 = MultiProof::default();
1218        sequencer.next_sequence = 3;
1219
1220        let ready = sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proof3).unwrap());
1221        assert_eq!(ready.len(), 0);
1222        assert!(sequencer.has_pending());
1223
1224        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1225        assert_eq!(ready.len(), 1);
1226        assert!(sequencer.has_pending());
1227
1228        let ready = sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proof2).unwrap());
1229        assert_eq!(ready.len(), 2);
1230        assert!(!sequencer.has_pending());
1231    }
1232
1233    #[test]
1234    fn test_add_proof_with_gaps() {
1235        let mut sequencer = ProofSequencer::default();
1236        let proof1 = MultiProof::default();
1237        let proof3 = MultiProof::default();
1238        sequencer.next_sequence = 3;
1239
1240        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1241        assert_eq!(ready.len(), 1);
1242
1243        let ready = sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proof3).unwrap());
1244        assert_eq!(ready.len(), 0);
1245        assert!(sequencer.has_pending());
1246    }
1247
1248    #[test]
1249    fn test_add_proof_duplicate_sequence() {
1250        let mut sequencer = ProofSequencer::default();
1251        let proof1 = MultiProof::default();
1252        let proof2 = MultiProof::default();
1253
1254        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1255        assert_eq!(ready.len(), 1);
1256
1257        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof2).unwrap());
1258        assert_eq!(ready.len(), 0);
1259        assert!(!sequencer.has_pending());
1260    }
1261
1262    #[test]
1263    fn test_add_proof_batch_processing() {
1264        let mut sequencer = ProofSequencer::default();
1265        let proofs: Vec<_> = (0..5).map(|_| MultiProof::default()).collect();
1266        sequencer.next_sequence = 5;
1267
1268        sequencer.add_proof(4, SparseTrieUpdate::from_multiproof(proofs[4].clone()).unwrap());
1269        sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proofs[2].clone()).unwrap());
1270        sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proofs[1].clone()).unwrap());
1271        sequencer.add_proof(3, SparseTrieUpdate::from_multiproof(proofs[3].clone()).unwrap());
1272
1273        let ready =
1274            sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proofs[0].clone()).unwrap());
1275        assert_eq!(ready.len(), 5);
1276        assert!(!sequencer.has_pending());
1277    }
1278
1279    fn create_get_proof_targets_state() -> HashedPostState {
1280        let mut state = HashedPostState::default();
1281
1282        let addr1 = B256::random();
1283        let addr2 = B256::random();
1284        state.accounts.insert(addr1, Some(Default::default()));
1285        state.accounts.insert(addr2, Some(Default::default()));
1286
1287        let mut storage = HashedStorage::default();
1288        let slot1 = B256::random();
1289        let slot2 = B256::random();
1290        storage.storage.insert(slot1, U256::ZERO);
1291        storage.storage.insert(slot2, U256::from(1));
1292        state.storages.insert(addr1, storage);
1293
1294        state
1295    }
1296
1297    #[test]
1298    fn test_get_proof_targets_new_account_targets() {
1299        let state = create_get_proof_targets_state();
1300        let fetched = MultiProofTargets::default();
1301
1302        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1303
1304        // should return all accounts as targets since nothing was fetched before
1305        assert_eq!(targets.len(), state.accounts.len());
1306        for addr in state.accounts.keys() {
1307            assert!(targets.contains_key(addr));
1308        }
1309    }
1310
1311    #[test]
1312    fn test_get_proof_targets_new_storage_targets() {
1313        let state = create_get_proof_targets_state();
1314        let fetched = MultiProofTargets::default();
1315
1316        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1317
1318        // verify storage slots are included for accounts with storage
1319        for (addr, storage) in &state.storages {
1320            assert!(targets.contains_key(addr));
1321            let target_slots = &targets[addr];
1322            assert_eq!(target_slots.len(), storage.storage.len());
1323            for slot in storage.storage.keys() {
1324                assert!(target_slots.contains(slot));
1325            }
1326        }
1327    }
1328
1329    #[test]
1330    fn test_get_proof_targets_filter_already_fetched_accounts() {
1331        let state = create_get_proof_targets_state();
1332        let mut fetched = MultiProofTargets::default();
1333
1334        // select an account that has no storage updates
1335        let fetched_addr = state
1336            .accounts
1337            .keys()
1338            .find(|&&addr| !state.storages.contains_key(&addr))
1339            .expect("Should have an account without storage");
1340
1341        // mark the account as already fetched
1342        fetched.insert(*fetched_addr, HashSet::default());
1343
1344        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1345
1346        // should not include the already fetched account since it has no storage updates
1347        assert!(!targets.contains_key(fetched_addr));
1348        // other accounts should still be included
1349        assert_eq!(targets.len(), state.accounts.len() - 1);
1350    }
1351
1352    #[test]
1353    fn test_get_proof_targets_filter_already_fetched_storage() {
1354        let state = create_get_proof_targets_state();
1355        let mut fetched = MultiProofTargets::default();
1356
1357        // mark one storage slot as already fetched
1358        let (addr, storage) = state.storages.iter().next().unwrap();
1359        let mut fetched_slots = HashSet::default();
1360        let fetched_slot = *storage.storage.keys().next().unwrap();
1361        fetched_slots.insert(fetched_slot);
1362        fetched.insert(*addr, fetched_slots);
1363
1364        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1365
1366        // should not include the already fetched storage slot
1367        let target_slots = &targets[addr];
1368        assert!(!target_slots.contains(&fetched_slot));
1369        assert_eq!(target_slots.len(), storage.storage.len() - 1);
1370    }
1371
1372    #[test]
1373    fn test_get_proof_targets_empty_state() {
1374        let state = HashedPostState::default();
1375        let fetched = MultiProofTargets::default();
1376
1377        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1378
1379        assert!(targets.is_empty());
1380    }
1381
1382    #[test]
1383    fn test_get_proof_targets_mixed_fetched_state() {
1384        let mut state = HashedPostState::default();
1385        let mut fetched = MultiProofTargets::default();
1386
1387        let addr1 = B256::random();
1388        let addr2 = B256::random();
1389        let slot1 = B256::random();
1390        let slot2 = B256::random();
1391
1392        state.accounts.insert(addr1, Some(Default::default()));
1393        state.accounts.insert(addr2, Some(Default::default()));
1394
1395        let mut storage = HashedStorage::default();
1396        storage.storage.insert(slot1, U256::ZERO);
1397        storage.storage.insert(slot2, U256::from(1));
1398        state.storages.insert(addr1, storage);
1399
1400        let mut fetched_slots = HashSet::default();
1401        fetched_slots.insert(slot1);
1402        fetched.insert(addr1, fetched_slots);
1403
1404        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1405
1406        assert!(targets.contains_key(&addr2));
1407        assert!(!targets[&addr1].contains(&slot1));
1408        assert!(targets[&addr1].contains(&slot2));
1409    }
1410
1411    #[test]
1412    fn test_get_proof_targets_unmodified_account_with_storage() {
1413        let mut state = HashedPostState::default();
1414        let fetched = MultiProofTargets::default();
1415
1416        let addr = B256::random();
1417        let slot1 = B256::random();
1418        let slot2 = B256::random();
1419
1420        // don't add the account to state.accounts (simulating unmodified account)
1421        // but add storage updates for this account
1422        let mut storage = HashedStorage::default();
1423        storage.storage.insert(slot1, U256::from(1));
1424        storage.storage.insert(slot2, U256::from(2));
1425        state.storages.insert(addr, storage);
1426
1427        assert!(!state.accounts.contains_key(&addr));
1428        assert!(!fetched.contains_key(&addr));
1429
1430        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1431
1432        // verify that we still get the storage slots for the unmodified account
1433        assert!(targets.contains_key(&addr));
1434
1435        let target_slots = &targets[&addr];
1436        assert_eq!(target_slots.len(), 2);
1437        assert!(target_slots.contains(&slot1));
1438        assert!(target_slots.contains(&slot2));
1439    }
1440
1441    #[test]
1442    fn test_get_prefetch_proof_targets_no_duplicates() {
1443        let test_provider_factory = create_test_provider_factory();
1444        let mut test_state_root_task = create_test_state_root_task(test_provider_factory);
1445
1446        // populate some targets
1447        let mut targets = MultiProofTargets::default();
1448        let addr1 = B256::random();
1449        let addr2 = B256::random();
1450        let slot1 = B256::random();
1451        let slot2 = B256::random();
1452        targets.insert(addr1, std::iter::once(slot1).collect());
1453        targets.insert(addr2, std::iter::once(slot2).collect());
1454
1455        let prefetch_proof_targets =
1456            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1457
1458        // check that the prefetch proof targets are the same because there are no fetched proof
1459        // targets yet
1460        assert_eq!(prefetch_proof_targets, targets);
1461
1462        // add a different addr and slot to fetched proof targets
1463        let addr3 = B256::random();
1464        let slot3 = B256::random();
1465        test_state_root_task.fetched_proof_targets.insert(addr3, std::iter::once(slot3).collect());
1466
1467        let prefetch_proof_targets =
1468            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1469
1470        // check that the prefetch proof targets are the same because the fetched proof targets
1471        // don't overlap with the prefetch targets
1472        assert_eq!(prefetch_proof_targets, targets);
1473    }
1474
1475    #[test]
1476    fn test_get_prefetch_proof_targets_remove_subset() {
1477        let test_provider_factory = create_test_provider_factory();
1478        let mut test_state_root_task = create_test_state_root_task(test_provider_factory);
1479
1480        // populate some targe
1481        let mut targets = MultiProofTargets::default();
1482        let addr1 = B256::random();
1483        let addr2 = B256::random();
1484        let slot1 = B256::random();
1485        let slot2 = B256::random();
1486        targets.insert(addr1, std::iter::once(slot1).collect());
1487        targets.insert(addr2, std::iter::once(slot2).collect());
1488
1489        // add a subset of the first target to fetched proof targets
1490        test_state_root_task.fetched_proof_targets.insert(addr1, std::iter::once(slot1).collect());
1491
1492        let prefetch_proof_targets =
1493            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1494
1495        // check that the prefetch proof targets do not include the subset
1496        assert_eq!(prefetch_proof_targets.len(), 1);
1497        assert!(!prefetch_proof_targets.contains_key(&addr1));
1498        assert!(prefetch_proof_targets.contains_key(&addr2));
1499
1500        // now add one more slot to the prefetch targets
1501        let slot3 = B256::random();
1502        targets.get_mut(&addr1).unwrap().insert(slot3);
1503
1504        let prefetch_proof_targets =
1505            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1506
1507        // check that the prefetch proof targets do not include the subset
1508        // but include the new slot
1509        assert_eq!(prefetch_proof_targets.len(), 2);
1510        assert!(prefetch_proof_targets.contains_key(&addr1));
1511        assert_eq!(
1512            *prefetch_proof_targets.get(&addr1).unwrap(),
1513            std::iter::once(slot3).collect::<B256Set>()
1514        );
1515        assert!(prefetch_proof_targets.contains_key(&addr2));
1516        assert_eq!(
1517            *prefetch_proof_targets.get(&addr2).unwrap(),
1518            std::iter::once(slot2).collect::<B256Set>()
1519        );
1520    }
1521
1522    #[test]
1523    fn test_get_proof_targets_with_removed_storage_keys() {
1524        let mut state = HashedPostState::default();
1525        let mut fetched = MultiProofTargets::default();
1526        let mut multi_added_removed_keys = MultiAddedRemovedKeys::new();
1527
1528        let addr = B256::random();
1529        let slot1 = B256::random();
1530        let slot2 = B256::random();
1531
1532        // add account to state
1533        state.accounts.insert(addr, Some(Default::default()));
1534
1535        // add storage updates
1536        let mut storage = HashedStorage::default();
1537        storage.storage.insert(slot1, U256::from(100));
1538        storage.storage.insert(slot2, U256::from(200));
1539        state.storages.insert(addr, storage);
1540
1541        // mark slot1 as already fetched
1542        let mut fetched_slots = HashSet::default();
1543        fetched_slots.insert(slot1);
1544        fetched.insert(addr, fetched_slots);
1545
1546        // update multi_added_removed_keys to mark slot1 as removed
1547        let mut removed_state = HashedPostState::default();
1548        let mut removed_storage = HashedStorage::default();
1549        removed_storage.storage.insert(slot1, U256::ZERO); // U256::ZERO marks as removed
1550        removed_state.storages.insert(addr, removed_storage);
1551        multi_added_removed_keys.update_with_state(&removed_state);
1552
1553        let targets = get_proof_targets(&state, &fetched, &multi_added_removed_keys);
1554
1555        // slot1 should be included despite being fetched, because it's marked as removed
1556        assert!(targets.contains_key(&addr));
1557        let target_slots = &targets[&addr];
1558        assert_eq!(target_slots.len(), 2);
1559        assert!(target_slots.contains(&slot1)); // included because it's removed
1560        assert!(target_slots.contains(&slot2)); // included because it's not fetched
1561    }
1562
1563    #[test]
1564    fn test_get_proof_targets_with_wiped_storage() {
1565        let mut state = HashedPostState::default();
1566        let fetched = MultiProofTargets::default();
1567        let multi_added_removed_keys = MultiAddedRemovedKeys::new();
1568
1569        let addr = B256::random();
1570        let slot1 = B256::random();
1571
1572        // add account to state
1573        state.accounts.insert(addr, Some(Default::default()));
1574
1575        // add wiped storage
1576        let mut storage = HashedStorage::new(true);
1577        storage.storage.insert(slot1, U256::from(100));
1578        state.storages.insert(addr, storage);
1579
1580        let targets = get_proof_targets(&state, &fetched, &multi_added_removed_keys);
1581
1582        // account should be included because storage is wiped and account wasn't fetched
1583        assert!(targets.contains_key(&addr));
1584        let target_slots = &targets[&addr];
1585        assert_eq!(target_slots.len(), 1);
1586        assert!(target_slots.contains(&slot1));
1587    }
1588
1589    #[test]
1590    fn test_get_proof_targets_removed_keys_not_in_state_update() {
1591        let mut state = HashedPostState::default();
1592        let mut fetched = MultiProofTargets::default();
1593        let mut multi_added_removed_keys = MultiAddedRemovedKeys::new();
1594
1595        let addr = B256::random();
1596        let slot1 = B256::random();
1597        let slot2 = B256::random();
1598        let slot3 = B256::random();
1599
1600        // add account to state
1601        state.accounts.insert(addr, Some(Default::default()));
1602
1603        // add storage updates for slot1 and slot2 only
1604        let mut storage = HashedStorage::default();
1605        storage.storage.insert(slot1, U256::from(100));
1606        storage.storage.insert(slot2, U256::from(200));
1607        state.storages.insert(addr, storage);
1608
1609        // mark all slots as already fetched
1610        let mut fetched_slots = HashSet::default();
1611        fetched_slots.insert(slot1);
1612        fetched_slots.insert(slot2);
1613        fetched_slots.insert(slot3); // slot3 is fetched but not in state update
1614        fetched.insert(addr, fetched_slots);
1615
1616        // mark slot3 as removed (even though it's not in the state update)
1617        let mut removed_state = HashedPostState::default();
1618        let mut removed_storage = HashedStorage::default();
1619        removed_storage.storage.insert(slot3, U256::ZERO);
1620        removed_state.storages.insert(addr, removed_storage);
1621        multi_added_removed_keys.update_with_state(&removed_state);
1622
1623        let targets = get_proof_targets(&state, &fetched, &multi_added_removed_keys);
1624
1625        // only slots in the state update can be included, so slot3 should not appear
1626        assert!(!targets.contains_key(&addr));
1627    }
1628}