reth_engine_tree/tree/payload_processor/multiproof.rs

//! Multiproof task related functionality.

use alloy_evm::block::StateChangeSource;
use alloy_primitives::{
    keccak256,
    map::{B256Set, HashSet},
    B256,
};
use crossbeam_channel::{unbounded, Receiver as CrossbeamReceiver, Sender as CrossbeamSender};
use dashmap::DashMap;
use derive_more::derive::Deref;
use metrics::{Gauge, Histogram};
use reth_metrics::Metrics;
use reth_revm::state::EvmState;
use reth_trie::{
    added_removed_keys::MultiAddedRemovedKeys, DecodedMultiProof, HashedPostState, HashedStorage,
    MultiProofTargets,
};
use reth_trie_parallel::{
    proof::ParallelProof,
    proof_task::{
        AccountMultiproofInput, ProofResultContext, ProofResultMessage, ProofWorkerHandle,
        StorageProofInput,
    },
};
use std::{collections::BTreeMap, mem, ops::DerefMut, sync::Arc, time::Instant};
use tracing::{debug, error, instrument, trace};

/// Maximum number of targets to batch together for prefetch batching.
/// Prefetches are just proof requests (no state merging), so we allow a higher cap than state
/// updates
const PREFETCH_MAX_BATCH_TARGETS: usize = 512;

/// Maximum number of prefetch messages to batch together.
/// Prevents excessive batching even with small messages.
const PREFETCH_MAX_BATCH_MESSAGES: usize = 16;

/// Maximum number of targets to batch together for state updates.
/// Lower than prefetch because state updates require additional processing (hashing, state
/// partitioning) before dispatch.
const STATE_UPDATE_MAX_BATCH_TARGETS: usize = 64;

/// Preallocation hint for state update batching to avoid repeated reallocations on small bursts.
const STATE_UPDATE_BATCH_PREALLOC: usize = 16;

/// The default maximum number of account and storage proof targets to be fetched by a single
/// worker. If exceeded, chunking is forced regardless of worker availability.
const DEFAULT_MAX_TARGETS_FOR_CHUNKING: usize = 300;

/// A trie update that can be applied to the sparse trie alongside the proofs for touched parts of
/// the state.
#[derive(Default, Debug)]
pub struct SparseTrieUpdate {
    /// The state update that was used to calculate the proof
    pub(crate) state: HashedPostState,
    /// The calculated multiproof
    pub(crate) multiproof: DecodedMultiProof,
}

impl SparseTrieUpdate {
    /// Returns true if the update is empty.
    pub(super) fn is_empty(&self) -> bool {
        self.state.is_empty() && self.multiproof.is_empty()
    }

    /// Construct update from multiproof.
    #[cfg(test)]
    pub(super) fn from_multiproof(multiproof: reth_trie::MultiProof) -> alloy_rlp::Result<Self> {
        Ok(Self { multiproof: multiproof.try_into()?, ..Default::default() })
    }

    /// Extend update with contents of the other.
    pub(super) fn extend(&mut self, other: Self) {
        self.state.extend(other.state);
        self.multiproof.extend(other.multiproof);
    }
}

/// Messages used internally by the multi proof task.
#[derive(Debug)]
pub(super) enum MultiProofMessage {
    /// Prefetch proof targets
    PrefetchProofs(MultiProofTargets),
    /// New state update from transaction execution with its source
    StateUpdate(StateChangeSource, EvmState),
    /// State update that can be applied to the sparse trie without any new proofs.
    ///
    /// It can be the case when all accounts and storage slots from the state update were already
    /// fetched and revealed.
    EmptyProof {
        /// The index of this proof in the sequence of state updates
        sequence_number: u64,
        /// The state update that was used to calculate the proof
        state: HashedPostState,
    },
    /// Signals state update stream end.
    ///
    /// This is triggered by block execution, indicating that no additional state updates are
    /// expected.
    FinishedStateUpdates,
}

/// Handle to track proof calculation ordering.
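///
/// Proofs may complete out of order; [`Self::add_proof`] buffers them until the next expected
/// sequence number arrives and then returns the contiguous run. A minimal sketch of that
/// behaviour (illustrative only, assuming `update_a` and `update_b` are [`SparseTrieUpdate`]
/// values; not a doctest):
///
/// ```ignore
/// let mut sequencer = ProofSequencer::default();
/// let seq_a = sequencer.next_sequence(); // 0
/// let seq_b = sequencer.next_sequence(); // 1
///
/// // Proof 1 arrives first and is buffered: proof 0 has not been delivered yet.
/// assert!(sequencer.add_proof(seq_b, update_b).is_empty());
///
/// // Once proof 0 arrives, both are returned in sequence order.
/// assert_eq!(sequencer.add_proof(seq_a, update_a).len(), 2);
/// ```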
#[derive(Debug, Default)]
struct ProofSequencer {
    /// The next proof sequence number to be produced.
    next_sequence: u64,
    /// The next sequence number expected to be delivered.
    next_to_deliver: u64,
    /// Buffer for out-of-order proofs and corresponding state updates
    pending_proofs: BTreeMap<u64, SparseTrieUpdate>,
}

impl ProofSequencer {
    /// Gets the next sequence number and increments the counter
    const fn next_sequence(&mut self) -> u64 {
        let seq = self.next_sequence;
        self.next_sequence += 1;
        seq
    }

    /// Adds a proof with the corresponding state update and returns all sequential proofs and state
    /// updates if we have a continuous sequence
    fn add_proof(&mut self, sequence: u64, update: SparseTrieUpdate) -> Vec<SparseTrieUpdate> {
        if sequence >= self.next_to_deliver {
            self.pending_proofs.insert(sequence, update);
        }

        // return early if we don't have the next expected proof
        if !self.pending_proofs.contains_key(&self.next_to_deliver) {
            return Vec::new()
        }

        let mut consecutive_proofs = Vec::with_capacity(self.pending_proofs.len());
        let mut current_sequence = self.next_to_deliver;

        // keep collecting proofs and state updates as long as we have consecutive sequence numbers
        while let Some(pending) = self.pending_proofs.remove(&current_sequence) {
            consecutive_proofs.push(pending);
            current_sequence += 1;

            // if we don't have the next number, stop collecting
            if !self.pending_proofs.contains_key(&current_sequence) {
                break;
            }
        }

        self.next_to_deliver += consecutive_proofs.len() as u64;

        consecutive_proofs
    }

    /// Returns true if we still have pending proofs
    pub(crate) fn has_pending(&self) -> bool {
        !self.pending_proofs.is_empty()
    }
}

/// A wrapper for the sender that signals completion when dropped.
///
/// This type is intended to be used in combination with the EVM executor state hook.
/// It should trigger once the block has been executed, i.e. after the last state update has been
/// sent. This triggers the exit condition of the multiproof task.
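///
/// A minimal usage sketch (illustrative only; `task` stands in for a [`MultiProofTask`]):
///
/// ```ignore
/// {
///     let hook_sender = StateHookSender::new(task.state_root_message_sender());
///     // ... the state hook sends `MultiProofMessage::StateUpdate`s through `hook_sender` ...
/// } // dropped here: `MultiProofMessage::FinishedStateUpdates` is sent automatically
/// ```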
#[derive(Deref, Debug)]
pub(super) struct StateHookSender(CrossbeamSender<MultiProofMessage>);

impl StateHookSender {
    pub(crate) const fn new(inner: CrossbeamSender<MultiProofMessage>) -> Self {
        Self(inner)
    }
}

impl Drop for StateHookSender {
    fn drop(&mut self) {
        // Send completion signal when the sender is dropped
        let _ = self.0.send(MultiProofMessage::FinishedStateUpdates);
    }
}

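/// Converts an [`EvmState`] from transaction execution into a [`HashedPostState`] by hashing the
/// addresses of touched accounts and the keys of changed storage slots with `keccak256`.
///
/// Self-destructed accounts are recorded with wiped storage; accounts that were not touched and
/// storage slots that did not change are skipped.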
pub(crate) fn evm_state_to_hashed_post_state(update: EvmState) -> HashedPostState {
    let mut hashed_state = HashedPostState::with_capacity(update.len());

    for (address, account) in update {
        if account.is_touched() {
            let hashed_address = keccak256(address);
            trace!(target: "engine::tree::payload_processor::multiproof", ?address, ?hashed_address, "Adding account to state update");

            let destroyed = account.is_selfdestructed();
            let info = if destroyed { None } else { Some(account.info.into()) };
            hashed_state.accounts.insert(hashed_address, info);

            let mut changed_storage_iter = account
                .storage
                .into_iter()
                .filter(|(_slot, value)| value.is_changed())
                .map(|(slot, value)| (keccak256(B256::from(slot)), value.present_value))
                .peekable();

            if destroyed {
                hashed_state.storages.insert(hashed_address, HashedStorage::new(true));
            } else if changed_storage_iter.peek().is_some() {
                hashed_state
                    .storages
                    .insert(hashed_address, HashedStorage::from_iter(false, changed_storage_iter));
            }
        }
    }

    hashed_state
}

/// A pending multiproof task, either [`StorageMultiproofInput`] or [`MultiproofInput`].
#[derive(Debug)]
enum PendingMultiproofTask {
    /// A storage multiproof task input.
    Storage(StorageMultiproofInput),
    /// A regular multiproof task input.
    Regular(MultiproofInput),
}

impl PendingMultiproofTask {
    /// Returns the proof sequence number of the task.
    const fn proof_sequence_number(&self) -> u64 {
        match self {
            Self::Storage(input) => input.proof_sequence_number,
            Self::Regular(input) => input.proof_sequence_number,
        }
    }

    /// Returns whether or not the proof targets are empty.
    fn proof_targets_is_empty(&self) -> bool {
        match self {
            Self::Storage(input) => input.proof_targets.is_empty(),
            Self::Regular(input) => input.proof_targets.is_empty(),
        }
    }

    /// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
    fn send_empty_proof(self) {
        match self {
            Self::Storage(input) => input.send_empty_proof(),
            Self::Regular(input) => input.send_empty_proof(),
        }
    }
}

impl From<StorageMultiproofInput> for PendingMultiproofTask {
    fn from(input: StorageMultiproofInput) -> Self {
        Self::Storage(input)
    }
}

impl From<MultiproofInput> for PendingMultiproofTask {
    fn from(input: MultiproofInput) -> Self {
        Self::Regular(input)
    }
}

/// Input parameters for dispatching a dedicated storage multiproof calculation.
#[derive(Debug)]
struct StorageMultiproofInput {
    hashed_state_update: HashedPostState,
    hashed_address: B256,
    proof_targets: B256Set,
    proof_sequence_number: u64,
    state_root_message_sender: CrossbeamSender<MultiProofMessage>,
    multi_added_removed_keys: Arc<MultiAddedRemovedKeys>,
}

impl StorageMultiproofInput {
    /// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
    fn send_empty_proof(self) {
        let _ = self.state_root_message_sender.send(MultiProofMessage::EmptyProof {
            sequence_number: self.proof_sequence_number,
            state: self.hashed_state_update,
        });
    }
}

/// Input parameters for dispatching a multiproof calculation.
#[derive(Debug)]
struct MultiproofInput {
    source: Option<StateChangeSource>,
    hashed_state_update: HashedPostState,
    proof_targets: MultiProofTargets,
    proof_sequence_number: u64,
    state_root_message_sender: CrossbeamSender<MultiProofMessage>,
    multi_added_removed_keys: Option<Arc<MultiAddedRemovedKeys>>,
}

impl MultiproofInput {
    /// Destroys the input and sends a [`MultiProofMessage::EmptyProof`] message to the sender.
    fn send_empty_proof(self) {
        let _ = self.state_root_message_sender.send(MultiProofMessage::EmptyProof {
            sequence_number: self.proof_sequence_number,
            state: self.hashed_state_update,
        });
    }
}

/// Coordinates multiproof dispatch between `MultiProofTask` and the parallel trie workers.
///
/// # Flow
/// 1. `MultiProofTask` asks the manager to dispatch either storage or account proof work.
/// 2. The manager builds the request, clones `proof_result_tx`, and hands everything to
///    [`ProofWorkerHandle`].
/// 3. A worker finishes the proof and sends a [`ProofResultMessage`] through the channel included
///    in the job.
/// 4. `MultiProofTask` consumes the message from the same channel and sequences it with
///    `ProofSequencer`.
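///
/// A small sketch of the two dispatch paths (illustrative only; `manager`, `storage_input`, and
/// `multiproof_input` are hypothetical [`MultiproofManager`] / [`StorageMultiproofInput`] /
/// [`MultiproofInput`] values):
///
/// ```ignore
/// // Storage-only change for a single account → storage worker pool.
/// manager.dispatch(storage_input.into());
/// // General account + storage change set → account worker pool.
/// manager.dispatch(multiproof_input.into());
/// // Empty targets never reach a worker: an `EmptyProof` message is sent back instead.
/// ```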
#[derive(Debug)]
pub struct MultiproofManager {
    /// Handle to the proof worker pools (storage and account).
    proof_worker_handle: ProofWorkerHandle,
    /// Cached storage proof roots for missed leaves; this maps
    /// hashed (missed) addresses to their storage proof roots.
    ///
    /// It is important to cache these. Otherwise, a common account
    /// (popular ERC-20, etc.) having missed leaves in its path would
    /// repeatedly calculate these proofs per interacting transaction
    /// (same account different slots).
    ///
    /// This also works well with chunking multiproofs, which may break
    /// a big account change into different chunks, which may repeatedly
    /// revisit missed leaves.
    missed_leaves_storage_roots: Arc<DashMap<B256, B256>>,
    /// Channel sender cloned into each dispatched job so workers can send back the
    /// `ProofResultMessage`.
    proof_result_tx: CrossbeamSender<ProofResultMessage>,
    /// Metrics
    metrics: MultiProofTaskMetrics,
}

impl MultiproofManager {
    /// Creates a new [`MultiproofManager`].
    fn new(
        metrics: MultiProofTaskMetrics,
        proof_worker_handle: ProofWorkerHandle,
        proof_result_tx: CrossbeamSender<ProofResultMessage>,
    ) -> Self {
        // Initialize the max worker gauges with the worker pool sizes
        metrics.max_storage_workers.set(proof_worker_handle.total_storage_workers() as f64);
        metrics.max_account_workers.set(proof_worker_handle.total_account_workers() as f64);

        Self {
            metrics,
            proof_worker_handle,
            missed_leaves_storage_roots: Default::default(),
            proof_result_tx,
        }
    }

    /// Dispatches a new multiproof calculation to worker pools.
    fn dispatch(&self, input: PendingMultiproofTask) {
        // If there are no proof targets, we can just send an empty multiproof back immediately
        if input.proof_targets_is_empty() {
            trace!(
                sequence_number = input.proof_sequence_number(),
                "No proof targets, sending empty multiproof back immediately"
            );
            input.send_empty_proof();
            return;
        }

        match input {
            PendingMultiproofTask::Storage(storage_input) => {
                self.dispatch_storage_proof(storage_input);
            }
            PendingMultiproofTask::Regular(multiproof_input) => {
                self.dispatch_multiproof(multiproof_input);
            }
        }
    }

    /// Dispatches a single storage proof calculation to worker pool.
    fn dispatch_storage_proof(&self, storage_multiproof_input: StorageMultiproofInput) {
        let StorageMultiproofInput {
            hashed_state_update,
            hashed_address,
            proof_targets,
            proof_sequence_number,
            multi_added_removed_keys,
            state_root_message_sender: _,
        } = storage_multiproof_input;

        let storage_targets = proof_targets.len();

        trace!(
            target: "engine::tree::payload_processor::multiproof",
            proof_sequence_number,
            ?proof_targets,
            storage_targets,
            "Dispatching storage proof to workers"
        );

        let start = Instant::now();

        // Create prefix set from targets
        let prefix_set = reth_trie::prefix_set::PrefixSetMut::from(
            proof_targets.iter().map(reth_trie::Nibbles::unpack),
        );
        let prefix_set = prefix_set.freeze();

        // Build computation input (data only)
        let input = StorageProofInput::new(
            hashed_address,
            prefix_set,
            proof_targets,
            true, // with_branch_node_masks
            Some(multi_added_removed_keys),
        );

        // Dispatch to storage worker
        if let Err(e) = self.proof_worker_handle.dispatch_storage_proof(
            input,
            ProofResultContext::new(
                self.proof_result_tx.clone(),
                proof_sequence_number,
                hashed_state_update,
                start,
            ),
        ) {
            error!(target: "engine::tree::payload_processor::multiproof", ?e, "Failed to dispatch storage proof");
            return;
        }

        self.metrics
            .active_storage_workers_histogram
            .record(self.proof_worker_handle.active_storage_workers() as f64);
        self.metrics
            .active_account_workers_histogram
            .record(self.proof_worker_handle.active_account_workers() as f64);
        self.metrics
            .pending_storage_multiproofs_histogram
            .record(self.proof_worker_handle.pending_storage_tasks() as f64);
        self.metrics
            .pending_account_multiproofs_histogram
            .record(self.proof_worker_handle.pending_account_tasks() as f64);
    }

    /// Signals that a multiproof calculation has finished.
    fn on_calculation_complete(&self) {
        self.metrics
            .active_storage_workers_histogram
            .record(self.proof_worker_handle.active_storage_workers() as f64);
        self.metrics
            .active_account_workers_histogram
            .record(self.proof_worker_handle.active_account_workers() as f64);
        self.metrics
            .pending_storage_multiproofs_histogram
            .record(self.proof_worker_handle.pending_storage_tasks() as f64);
        self.metrics
            .pending_account_multiproofs_histogram
            .record(self.proof_worker_handle.pending_account_tasks() as f64);
    }

    /// Dispatches a single multiproof calculation to worker pool.
    fn dispatch_multiproof(&self, multiproof_input: MultiproofInput) {
        let MultiproofInput {
            source,
            hashed_state_update,
            proof_targets,
            proof_sequence_number,
            state_root_message_sender: _,
            multi_added_removed_keys,
        } = multiproof_input;

        let missed_leaves_storage_roots = self.missed_leaves_storage_roots.clone();
        let account_targets = proof_targets.len();
        let storage_targets = proof_targets.values().map(|slots| slots.len()).sum::<usize>();

        trace!(
            target: "engine::tree::payload_processor::multiproof",
            proof_sequence_number,
            ?proof_targets,
            account_targets,
            storage_targets,
            ?source,
            "Dispatching multiproof to workers"
        );

        let start = Instant::now();

        // Extend prefix sets with targets
        let frozen_prefix_sets =
            ParallelProof::extend_prefix_sets_with_targets(&Default::default(), &proof_targets);

        // Dispatch account multiproof to worker pool with result sender
        let input = AccountMultiproofInput {
            targets: proof_targets,
            prefix_sets: frozen_prefix_sets,
            collect_branch_node_masks: true,
            multi_added_removed_keys,
            missed_leaves_storage_roots,
            // Workers will send ProofResultMessage directly to proof_result_rx
            proof_result_sender: ProofResultContext::new(
                self.proof_result_tx.clone(),
                proof_sequence_number,
                hashed_state_update,
                start,
            ),
        };

        if let Err(e) = self.proof_worker_handle.dispatch_account_multiproof(input) {
            error!(target: "engine::tree::payload_processor::multiproof", ?e, "Failed to dispatch account multiproof");
            return;
        }

        self.metrics
            .active_storage_workers_histogram
            .record(self.proof_worker_handle.active_storage_workers() as f64);
        self.metrics
            .active_account_workers_histogram
            .record(self.proof_worker_handle.active_account_workers() as f64);
        self.metrics
            .pending_storage_multiproofs_histogram
            .record(self.proof_worker_handle.pending_storage_tasks() as f64);
        self.metrics
            .pending_account_multiproofs_histogram
            .record(self.proof_worker_handle.pending_account_tasks() as f64);
    }
}

#[derive(Metrics, Clone)]
#[metrics(scope = "tree.root")]
pub(crate) struct MultiProofTaskMetrics {
    /// Histogram of active storage workers processing proofs.
    pub active_storage_workers_histogram: Histogram,
    /// Histogram of active account workers processing proofs.
    pub active_account_workers_histogram: Histogram,
    /// Gauge for the maximum number of storage workers in the pool.
    pub max_storage_workers: Gauge,
    /// Gauge for the maximum number of account workers in the pool.
    pub max_account_workers: Gauge,
    /// Histogram of pending storage multiproofs in the queue.
    pub pending_storage_multiproofs_histogram: Histogram,
    /// Histogram of pending account multiproofs in the queue.
    pub pending_account_multiproofs_histogram: Histogram,

    /// Histogram of the number of prefetch proof target accounts.
    pub prefetch_proof_targets_accounts_histogram: Histogram,
    /// Histogram of the number of prefetch proof target storages.
    pub prefetch_proof_targets_storages_histogram: Histogram,
    /// Histogram of the number of prefetch proof target chunks.
    pub prefetch_proof_chunks_histogram: Histogram,

    /// Histogram of the number of state update proof target accounts.
    pub state_update_proof_targets_accounts_histogram: Histogram,
    /// Histogram of the number of state update proof target storages.
    pub state_update_proof_targets_storages_histogram: Histogram,
    /// Histogram of the number of state update proof target chunks.
    pub state_update_proof_chunks_histogram: Histogram,

    /// Histogram of prefetch proof batch sizes (number of messages merged).
    pub prefetch_batch_size_histogram: Histogram,
    /// Histogram of state update batch sizes (number of messages merged).
    pub state_update_batch_size_histogram: Histogram,

    /// Histogram of proof calculation durations.
    pub proof_calculation_duration_histogram: Histogram,

    /// Histogram of sparse trie update durations.
    pub sparse_trie_update_duration_histogram: Histogram,
    /// Histogram of sparse trie final update durations.
    pub sparse_trie_final_update_duration_histogram: Histogram,
    /// Histogram of sparse trie total durations.
    pub sparse_trie_total_duration_histogram: Histogram,

    /// Histogram of state updates received.
    pub state_updates_received_histogram: Histogram,
    /// Histogram of proofs processed.
    pub proofs_processed_histogram: Histogram,
    /// Histogram of total time spent in the multiproof task.
    pub multiproof_task_total_duration_histogram: Histogram,
    /// Total time spent waiting for the first state update or prefetch request.
    pub first_update_wait_time_histogram: Histogram,
    /// Total time spent waiting for the last proof result.
    pub last_proof_wait_time_histogram: Histogram,
}

/// Standalone task that receives a transaction state stream and updates relevant
/// data structures to calculate state root.
///
/// ## Architecture: Dual-Channel Multiproof System
///
/// This task orchestrates parallel proof computation using a dual-channel architecture that
/// separates control messages from proof computation results:
///
/// ```text
/// ┌─────────────────────────────────────────────────────────────────┐
/// │                        MultiProofTask                            │
/// │                  Event Loop (crossbeam::select!)                 │
/// └──┬──────────────────────────────────────────────────────────▲───┘
///    │                                                           │
///    │ (1) Send proof request                                    │
///    │     via tx (control channel)                              │
///    │                                                           │
///    ▼                                                           │
/// ┌──────────────────────────────────────────────────────────────┐ │
/// │             MultiproofManager                                 │ │
/// │  - Deduplicates against fetched_proof_targets                 │ │
/// │  - Routes to appropriate worker pool                          │ │
/// └──┬─────────────────────────────────────────────────────────────┘ │
///    │                                                             │
///    │ (2) Dispatch to workers                                     │
///    │     OR send EmptyProof (fast path)                          │
///    ▼                                                             │
/// ┌──────────────────────────────────────────────────────────────┐ │
/// │              ProofWorkerHandle                                │ │
/// │  ┌─────────────────────┐   ┌────────────────────────┐        │ │
/// │  │ Storage Worker Pool │   │ Account Worker Pool    │        │ │
/// │  │ (spawn_blocking)    │   │ (spawn_blocking)       │        │ │
/// │  └─────────────────────┘   └────────────────────────┘        │ │
/// └──┬───────────────────────────────────────────────────────────┘ │
///    │                                                             │
///    │ (3) Compute proofs in parallel                              │
///    │     Send results back                                       │
///    │                                                             │
///    ▼                                                             │
/// ┌──────────────────────────────────────────────────────────────┐ │
/// │  proof_result_tx (crossbeam unbounded channel)                │ │
/// │    → ProofResultMessage { multiproof, sequence_number, ... }  │ │
/// └──────────────────────────────────────────────────────────────┘ │
///                                                                   │
///   (4) Receive via crossbeam::select! on two channels: ───────────┘
///       - rx: Control messages (PrefetchProofs, StateUpdate,
///             EmptyProof, FinishedStateUpdates)
///       - proof_result_rx: Computed proof results from workers
/// ```
///
/// ## Component Responsibilities
///
/// - **[`MultiProofTask`]**: Event loop coordinator
///   - Receives state updates from transaction execution
///   - Deduplicates proof targets against already-fetched proofs
///   - Sequences proofs to maintain transaction ordering
///   - Feeds sequenced updates to sparse trie task
///
/// - **[`MultiproofManager`]**: Calculation orchestrator
///   - Decides between fast path ([`EmptyProof`]) and worker dispatch
///   - Routes storage-only vs full multiproofs to appropriate workers
///   - Records metrics for monitoring
///
/// - **[`ProofWorkerHandle`]**: Worker pool manager
///   - Maintains separate pools for storage and account proofs
///   - Dispatches work to blocking threads (CPU-intensive)
///   - Sends results directly via `proof_result_tx` (bypasses control channel)
///
/// [`EmptyProof`]: MultiProofMessage::EmptyProof
/// [`ProofWorkerHandle`]: reth_trie_parallel::proof_task::ProofWorkerHandle
///
/// ## Dual-Channel Design Rationale
///
/// The system uses two separate crossbeam channels:
///
/// 1. **Control Channel (`tx`/`rx`)**: For orchestration messages
///    - `PrefetchProofs`: Pre-fetch proofs before execution
///    - `StateUpdate`: New transaction execution results
///    - `EmptyProof`: Fast path when all targets already fetched
///    - `FinishedStateUpdates`: Signal to drain pending work
///
/// 2. **Proof Result Channel (`proof_result_tx`/`proof_result_rx`)**: For worker results
///    - `ProofResultMessage`: Computed multiproofs from worker pools
///    - Direct path from workers to event loop (no intermediate hops)
///    - Keeps control messages separate from high-throughput proof data
///
/// This separation enables:
/// - **Non-blocking control**: Control messages never wait behind large proof data
/// - **Backpressure management**: Each channel can apply different policies
/// - **Clear ownership**: Workers only need proof result sender, not control channel
///
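/// A condensed sketch of how the two channel pairs are created and consumed (illustrative only;
/// the real wiring lives in the constructor and in `run()` below):
///
/// ```ignore
/// let (tx, rx) = unbounded::<MultiProofMessage>();                             // control channel
/// let (proof_result_tx, proof_result_rx) = unbounded::<ProofResultMessage>();  // worker results
///
/// crossbeam_channel::select_biased! {
///     recv(proof_result_rx) -> proof_msg => { /* sequence proof, forward to sparse trie */ }
///     recv(rx) -> msg => { /* prefetch / state update / empty proof / finished */ }
/// }
/// ```
///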
/// ## Initialization and Lifecycle
///
/// The task initializes a blinded sparse trie and subscribes to transaction state streams.
/// As it receives transaction execution results, it fetches proofs for relevant accounts
/// from the database and reveals them to the tree, then updates relevant leaves according
/// to transaction results. This feeds updates to the sparse trie task.
///
/// See the `run()` method documentation for detailed lifecycle flow.
#[derive(Debug)]
pub(super) struct MultiProofTask {
    /// The size of proof targets chunk to spawn in one calculation.
    /// If None, chunking is disabled and all targets are processed in a single proof.
    chunk_size: Option<usize>,
    /// Receiver for state root related messages (prefetch, state updates, finish signal).
    rx: CrossbeamReceiver<MultiProofMessage>,
    /// Sender for state root related messages.
    tx: CrossbeamSender<MultiProofMessage>,
    /// Receiver for proof results directly from workers.
    proof_result_rx: CrossbeamReceiver<ProofResultMessage>,
    /// Sender for state updates emitted by this type.
    to_sparse_trie: std::sync::mpsc::Sender<SparseTrieUpdate>,
    /// Proof targets that have been already fetched.
    fetched_proof_targets: MultiProofTargets,
    /// Tracks keys which have been added and removed throughout the entire block.
    multi_added_removed_keys: MultiAddedRemovedKeys,
    /// Proof sequencing handler.
    proof_sequencer: ProofSequencer,
    /// Manages calculation of multiproofs.
    multiproof_manager: MultiproofManager,
    /// multi proof task metrics
    metrics: MultiProofTaskMetrics,
    /// If the number of targets exceeds this value and chunking is enabled, chunking is forced
    /// regardless of whether any workers are available. This prevents very large tasks from
    /// landing on a single worker.
    max_targets_for_chunking: usize,
}

impl MultiProofTask {
    /// Creates a multiproof task with separate channels: control on `tx`/`rx`, proof results on
    /// `proof_result_rx`.
    pub(super) fn new(
        proof_worker_handle: ProofWorkerHandle,
        to_sparse_trie: std::sync::mpsc::Sender<SparseTrieUpdate>,
        chunk_size: Option<usize>,
    ) -> Self {
        let (tx, rx) = unbounded();
        let (proof_result_tx, proof_result_rx) = unbounded();
        let metrics = MultiProofTaskMetrics::default();

        Self {
            chunk_size,
            rx,
            tx,
            proof_result_rx,
            to_sparse_trie,
            fetched_proof_targets: Default::default(),
            multi_added_removed_keys: MultiAddedRemovedKeys::new(),
            proof_sequencer: ProofSequencer::default(),
            multiproof_manager: MultiproofManager::new(
                metrics.clone(),
                proof_worker_handle,
                proof_result_tx,
            ),
            metrics,
            max_targets_for_chunking: DEFAULT_MAX_TARGETS_FOR_CHUNKING,
        }
    }

    /// Returns a sender that can be used to send arbitrary [`MultiProofMessage`]s to this task.
    pub(super) fn state_root_message_sender(&self) -> CrossbeamSender<MultiProofMessage> {
        self.tx.clone()
    }

    /// Handles request for proof prefetch.
    ///
    /// Returns how many multiproof tasks were dispatched for the prefetch request.
    #[instrument(
        level = "debug",
        target = "engine::tree::payload_processor::multiproof",
        skip_all,
        fields(accounts = targets.len(), chunks = 0)
    )]
    fn on_prefetch_proof(&mut self, targets: MultiProofTargets) -> u64 {
        let proof_targets = self.get_prefetch_proof_targets(targets);
        self.fetched_proof_targets.extend_ref(&proof_targets);

        // Make sure all target accounts have an `AddedRemovedKeySet` in the
        // [`MultiAddedRemovedKeys`]. Even if there are not any known removed keys for the account,
        // we still want to optimistically fetch extension children for the leaf addition case.
        self.multi_added_removed_keys.touch_accounts(proof_targets.keys().copied());

        // Clone+Arc MultiAddedRemovedKeys for sharing with the dispatched multiproof tasks
        let multi_added_removed_keys = Arc::new(self.multi_added_removed_keys.clone());

        self.metrics.prefetch_proof_targets_accounts_histogram.record(proof_targets.len() as f64);
        self.metrics
            .prefetch_proof_targets_storages_histogram
            .record(proof_targets.values().map(|slots| slots.len()).sum::<usize>() as f64);

        let chunking_len = proof_targets.chunking_length();
        let available_account_workers =
            self.multiproof_manager.proof_worker_handle.available_account_workers();
        let available_storage_workers =
            self.multiproof_manager.proof_worker_handle.available_storage_workers();
        let num_chunks = dispatch_with_chunking(
            proof_targets,
            chunking_len,
            self.chunk_size,
            self.max_targets_for_chunking,
            available_account_workers,
            available_storage_workers,
            MultiProofTargets::chunks,
            |proof_targets| {
                self.multiproof_manager.dispatch(
                    MultiproofInput {
                        source: None,
                        hashed_state_update: Default::default(),
                        proof_targets,
                        proof_sequence_number: self.proof_sequencer.next_sequence(),
                        state_root_message_sender: self.tx.clone(),
                        multi_added_removed_keys: Some(multi_added_removed_keys.clone()),
                    }
                    .into(),
                );
            },
        );
        self.metrics.prefetch_proof_chunks_histogram.record(num_chunks as f64);

        num_chunks as u64
    }
    /// Returns true if all state updates have finished and all proofs have been processed.
    fn is_done(
        &self,
        proofs_processed: u64,
        state_update_proofs_requested: u64,
        prefetch_proofs_requested: u64,
        updates_finished: bool,
    ) -> bool {
        let all_proofs_processed =
            proofs_processed >= state_update_proofs_requested + prefetch_proofs_requested;
        let no_pending = !self.proof_sequencer.has_pending();
        trace!(
            target: "engine::tree::payload_processor::multiproof",
            proofs_processed,
            state_update_proofs_requested,
            prefetch_proofs_requested,
            no_pending,
            updates_finished,
            "Checking end condition"
        );
        all_proofs_processed && no_pending && updates_finished
    }

    /// Calls `get_proof_targets` with existing proof targets for prefetching.
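    ///
    /// For example (hypothetical hashed addresses, illustrative only): if account `0xa..` was
    /// already fetched with slots `{s1, s2}`, a new request for `0xa.. -> {s1}` is dropped as a
    /// subset, a request for `0xa.. -> {s1, s3}` is reduced to `{s3}`, and a request for a not yet
    /// fetched account `0xb..` is kept in full.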
    fn get_prefetch_proof_targets(&self, mut targets: MultiProofTargets) -> MultiProofTargets {
        // Here we want to filter out any targets that are already fetched
        //
        // This means we need to remove any storage slots that have already been fetched
        let mut duplicates = 0;

        // First remove all storage targets that are subsets of already fetched storage slots
        targets.retain(|hashed_address, target_storage| {
            let keep = self
                .fetched_proof_targets
                .get(hashed_address)
                // do NOT remove if None, because that means the account has not been fetched yet
                .is_none_or(|fetched_storage| {
                    // remove if a subset
                    !target_storage.is_subset(fetched_storage)
                });

            if !keep {
                duplicates += target_storage.len();
            }

            keep
        });

        // For all non-subset remaining targets, we have to calculate the difference
        for (hashed_address, target_storage) in targets.deref_mut() {
            let Some(fetched_storage) = self.fetched_proof_targets.get(hashed_address) else {
                // this means the account has not been fetched yet, so we must fetch everything
                // associated with this account
                continue;
            };

            let prev_target_storage_len = target_storage.len();

            // keep only the storage slots that have not been fetched yet
            //
            // we already removed subsets, so this should only remove duplicates
            target_storage.retain(|slot| !fetched_storage.contains(slot));

            duplicates += prev_target_storage_len - target_storage.len();
        }

        if duplicates > 0 {
            trace!(target: "engine::tree::payload_processor::multiproof", duplicates, "Removed duplicate prefetch proof targets");
        }

        targets
    }

    /// Handles state updates.
    ///
    /// Returns how many proof dispatches were spawned (including an `EmptyProof` for already
    /// fetched targets).
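    ///
    /// Targets that were already fetched are sent straight back as
    /// [`MultiProofMessage::EmptyProof`]; only the not-yet-fetched remainder is dispatched to the
    /// proof workers, in chunks if chunking applies.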
    #[instrument(
        level = "debug",
        target = "engine::tree::payload_processor::multiproof",
        skip(self, update),
        fields(accounts = update.len(), chunks = 0)
    )]
    fn on_state_update(&mut self, source: StateChangeSource, update: EvmState) -> u64 {
        let hashed_state_update = evm_state_to_hashed_post_state(update);

        // Update removed keys based on the state update.
        self.multi_added_removed_keys.update_with_state(&hashed_state_update);

        // Split the state update into already fetched and not fetched according to the proof
        // targets.
        let (fetched_state_update, not_fetched_state_update) = hashed_state_update
            .partition_by_targets(&self.fetched_proof_targets, &self.multi_added_removed_keys);

        let mut state_updates = 0;
        // If there are any accounts or storage slots that we already fetched the proofs for,
        // send them immediately, as they don't require dispatching any additional multiproofs.
        if !fetched_state_update.is_empty() {
            let _ = self.tx.send(MultiProofMessage::EmptyProof {
                sequence_number: self.proof_sequencer.next_sequence(),
                state: fetched_state_update,
            });
            state_updates += 1;
        }

        // Clone+Arc MultiAddedRemovedKeys for sharing with the dispatched multiproof tasks
        let multi_added_removed_keys = Arc::new(self.multi_added_removed_keys.clone());

        let chunking_len = not_fetched_state_update.chunking_length();
        let mut spawned_proof_targets = MultiProofTargets::default();
        let available_account_workers =
            self.multiproof_manager.proof_worker_handle.available_account_workers();
        let available_storage_workers =
            self.multiproof_manager.proof_worker_handle.available_storage_workers();
        let num_chunks = dispatch_with_chunking(
            not_fetched_state_update,
            chunking_len,
            self.chunk_size,
            self.max_targets_for_chunking,
            available_account_workers,
            available_storage_workers,
            HashedPostState::chunks,
            |hashed_state_update| {
                let proof_targets = get_proof_targets(
                    &hashed_state_update,
                    &self.fetched_proof_targets,
                    &multi_added_removed_keys,
                );
                spawned_proof_targets.extend_ref(&proof_targets);

                self.multiproof_manager.dispatch(
                    MultiproofInput {
                        source: Some(source),
                        hashed_state_update,
                        proof_targets,
                        proof_sequence_number: self.proof_sequencer.next_sequence(),
                        state_root_message_sender: self.tx.clone(),
                        multi_added_removed_keys: Some(multi_added_removed_keys.clone()),
                    }
                    .into(),
                );
            },
        );
        self.metrics
            .state_update_proof_targets_accounts_histogram
            .record(spawned_proof_targets.len() as f64);
        self.metrics
            .state_update_proof_targets_storages_histogram
            .record(spawned_proof_targets.values().map(|slots| slots.len()).sum::<usize>() as f64);
        self.metrics.state_update_proof_chunks_histogram.record(num_chunks as f64);

        self.fetched_proof_targets.extend(spawned_proof_targets);

        state_updates + num_chunks as u64
    }

    /// Handler for new proof calculated, aggregates all the existing sequential proofs.
    fn on_proof(
        &mut self,
        sequence_number: u64,
        update: SparseTrieUpdate,
    ) -> Option<SparseTrieUpdate> {
        let ready_proofs = self.proof_sequencer.add_proof(sequence_number, update);

        ready_proofs
            .into_iter()
            // Merge all ready proofs and state updates
            .reduce(|mut acc_update, update| {
                acc_update.extend(update);
                acc_update
            })
            // Return None if the resulting proof is empty
            .filter(|proof| !proof.is_empty())
    }

    /// Processes a multiproof message, batching consecutive same-type messages.
    ///
    /// Drains queued messages of the same type and merges them into one batch before processing,
    /// storing one pending message (different type or over-cap) to handle on the next iteration.
    /// This preserves ordering without requeuing onto the channel.
    ///
    /// Returns `true` if done, `false` to continue.
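    ///
    /// For example, consecutive [`MultiProofMessage::PrefetchProofs`] messages are merged while
    /// their combined target count stays under `PREFETCH_MAX_BATCH_TARGETS` (and the message count
    /// under `PREFETCH_MAX_BATCH_MESSAGES`); if a [`MultiProofMessage::StateUpdate`] arrives
    /// mid-batch, it is stashed in `pending_msg` and handled on the next loop iteration.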
    fn process_multiproof_message(
        &mut self,
        msg: MultiProofMessage,
        ctx: &mut MultiproofBatchCtx,
        batch_metrics: &mut MultiproofBatchMetrics,
    ) -> bool {
        match msg {
            // Prefetch proofs: batch consecutive prefetch requests up to target/message limits
            MultiProofMessage::PrefetchProofs(targets) => {
                trace!(target: "engine::tree::payload_processor::multiproof", "processing MultiProofMessage::PrefetchProofs");

                if ctx.first_update_time.is_none() {
                    self.metrics
                        .first_update_wait_time_histogram
                        .record(ctx.start.elapsed().as_secs_f64());
                    ctx.first_update_time = Some(Instant::now());
                    debug!(target: "engine::tree::payload_processor::multiproof", "Started state root calculation");
                }

                let mut accumulated_count = targets.chunking_length();
                ctx.accumulated_prefetch_targets.clear();
                ctx.accumulated_prefetch_targets.push(targets);

                // Batch consecutive prefetch messages up to limits.
                while accumulated_count < PREFETCH_MAX_BATCH_TARGETS &&
                    ctx.accumulated_prefetch_targets.len() < PREFETCH_MAX_BATCH_MESSAGES
                {
                    match self.rx.try_recv() {
                        Ok(MultiProofMessage::PrefetchProofs(next_targets)) => {
                            let next_count = next_targets.chunking_length();
                            if accumulated_count + next_count > PREFETCH_MAX_BATCH_TARGETS {
                                ctx.pending_msg =
                                    Some(MultiProofMessage::PrefetchProofs(next_targets));
                                break;
                            }
                            accumulated_count += next_count;
                            ctx.accumulated_prefetch_targets.push(next_targets);
                        }
                        Ok(other_msg) => {
                            ctx.pending_msg = Some(other_msg);
                            break;
                        }
                        Err(_) => break,
                    }
                }

                // Process all accumulated messages in a single batch
                let num_batched = ctx.accumulated_prefetch_targets.len();
                self.metrics.prefetch_batch_size_histogram.record(num_batched as f64);

                // Merge all accumulated prefetch targets into a single dispatch payload.
                // Use drain to preserve the buffer allocation.
                let mut accumulated_iter = ctx.accumulated_prefetch_targets.drain(..);
                let mut merged_targets =
                    accumulated_iter.next().expect("prefetch batch always has at least one entry");
                for next_targets in accumulated_iter {
                    merged_targets.extend(next_targets);
                }

                let account_targets = merged_targets.len();
                let storage_targets =
                    merged_targets.values().map(|slots| slots.len()).sum::<usize>();
                batch_metrics.prefetch_proofs_requested += self.on_prefetch_proof(merged_targets);
                trace!(
                    target: "engine::tree::payload_processor::multiproof",
                    account_targets,
                    storage_targets,
                    prefetch_proofs_requested = batch_metrics.prefetch_proofs_requested,
                    num_batched,
                    "Dispatched prefetch batch"
                );

                false
            }
            // State update: batch consecutive updates from the same source
            MultiProofMessage::StateUpdate(source, update) => {
                trace!(target: "engine::tree::payload_processor::multiproof", "processing MultiProofMessage::StateUpdate");

                if ctx.first_update_time.is_none() {
                    self.metrics
                        .first_update_wait_time_histogram
                        .record(ctx.start.elapsed().as_secs_f64());
                    ctx.first_update_time = Some(Instant::now());
                    debug!(target: "engine::tree::payload_processor::multiproof", "Started state root calculation");
                }

                // Accumulate messages including the first one; reuse buffer to avoid allocations.
                let mut accumulated_targets = estimate_evm_state_targets(&update);
                ctx.accumulated_state_updates.clear();
                ctx.accumulated_state_updates.push((source, update));

                // Batch consecutive state update messages up to target limit.
                while accumulated_targets < STATE_UPDATE_MAX_BATCH_TARGETS {
                    match self.rx.try_recv() {
                        Ok(MultiProofMessage::StateUpdate(next_source, next_update)) => {
                            let (batch_source, batch_update) = &ctx.accumulated_state_updates[0];
                            if !can_batch_state_update(
                                *batch_source,
                                batch_update,
                                next_source,
                                &next_update,
                            ) {
                                ctx.pending_msg =
                                    Some(MultiProofMessage::StateUpdate(next_source, next_update));
                                break;
                            }

                            let next_estimate = estimate_evm_state_targets(&next_update);
                            // Would exceed batch cap; leave pending to dispatch on next iteration.
                            if accumulated_targets + next_estimate > STATE_UPDATE_MAX_BATCH_TARGETS
                            {
                                ctx.pending_msg =
                                    Some(MultiProofMessage::StateUpdate(next_source, next_update));
                                break;
                            }
                            accumulated_targets += next_estimate;
                            ctx.accumulated_state_updates.push((next_source, next_update));
                        }
                        Ok(other_msg) => {
                            ctx.pending_msg = Some(other_msg);
                            break;
                        }
                        Err(_) => break,
                    }
                }

                // Process all accumulated messages in a single batch
                let num_batched = ctx.accumulated_state_updates.len();
                self.metrics.state_update_batch_size_histogram.record(num_batched as f64);

                #[cfg(debug_assertions)]
                {
                    let batch_source = ctx.accumulated_state_updates[0].0;
                    let batch_update = &ctx.accumulated_state_updates[0].1;
                    debug_assert!(ctx.accumulated_state_updates.iter().all(|(source, update)| {
                        can_batch_state_update(batch_source, batch_update, *source, update)
                    }));
                }

                // Merge all accumulated updates into a single EvmState payload.
                // Use drain to preserve the buffer allocation.
                let mut accumulated_iter = ctx.accumulated_state_updates.drain(..);
                let (mut batch_source, mut merged_update) = accumulated_iter
                    .next()
                    .expect("state update batch always has at least one entry");
                for (next_source, next_update) in accumulated_iter {
                    batch_source = next_source;
                    merged_update.extend(next_update);
                }

                let batch_len = merged_update.len();
                batch_metrics.state_update_proofs_requested +=
                    self.on_state_update(batch_source, merged_update);
                trace!(
                    target: "engine::tree::payload_processor::multiproof",
                    ?batch_source,
                    len = batch_len,
                    state_update_proofs_requested = ?batch_metrics.state_update_proofs_requested,
                    num_batched,
                    "Dispatched state update batch"
                );

                false
            }
            // Signal that no more state updates will arrive
            MultiProofMessage::FinishedStateUpdates => {
                trace!(target: "engine::tree::payload_processor::multiproof", "processing MultiProofMessage::FinishedStateUpdates");

                ctx.updates_finished_time = Some(Instant::now());

                if self.is_done(
                    batch_metrics.proofs_processed,
                    batch_metrics.state_update_proofs_requested,
                    batch_metrics.prefetch_proofs_requested,
                    ctx.updates_finished(),
                ) {
                    debug!(
                        target: "engine::tree::payload_processor::multiproof",
                        "State updates finished and all proofs processed, ending calculation"
                    );
                    return true;
                }
                false
            }
            // Handle proof result with no trie nodes (state unchanged)
            MultiProofMessage::EmptyProof { sequence_number, state } => {
                trace!(target: "engine::tree::payload_processor::multiproof", "processing MultiProofMessage::EmptyProof");

                batch_metrics.proofs_processed += 1;

                if let Some(combined_update) = self.on_proof(
                    sequence_number,
                    SparseTrieUpdate { state, multiproof: Default::default() },
                ) {
                    let _ = self.to_sparse_trie.send(combined_update);
                }

                if self.is_done(
                    batch_metrics.proofs_processed,
                    batch_metrics.state_update_proofs_requested,
                    batch_metrics.prefetch_proofs_requested,
                    ctx.updates_finished(),
                ) {
                    debug!(
                        target: "engine::tree::payload_processor::multiproof",
                        "State updates finished and all proofs processed, ending calculation"
                    );
                    return true;
                }
                false
            }
        }
    }

    /// Starts the main loop that handles all incoming messages, fetches proofs, applies them to the
    /// sparse trie, updates the sparse trie, and eventually returns the state root.
    ///
    /// The lifecycle is the following:
    /// 1. Either [`MultiProofMessage::PrefetchProofs`] or [`MultiProofMessage::StateUpdate`] is
    ///    received from the engine.
    ///    * For [`MultiProofMessage::StateUpdate`], the state update is hashed with
    ///      [`evm_state_to_hashed_post_state`], and then proof targets ([`MultiProofTargets`]) are
1207    ///      extracted with [`get_proof_targets`].
1208    ///    * For both messages, proof targets are deduplicated according to `fetched_proof_targets`,
1209    ///      so that the proofs for accounts and storage slots that were already fetched are not
1210    ///      requested again.
1211    /// 2. Using the proof targets, a new multiproof is calculated using
1212    ///    [`MultiproofManager::dispatch`].
1213    ///    * If the list of proof targets is empty, the [`MultiProofMessage::EmptyProof`] message is
1214    ///      sent back to this task along with the original state update.
1215    ///    * Otherwise, the multiproof is dispatched to worker pools and results are sent directly
1216    ///      to this task via the `proof_result_rx` channel as [`ProofResultMessage`].
1217    /// 3. Either [`MultiProofMessage::EmptyProof`] (via control channel) or [`ProofResultMessage`]
1218    ///    (via proof result channel) is received.
1219    ///    * The multiproof is added to the [`ProofSequencer`].
1220    ///    * If the proof sequencer has a contiguous sequence of multiproofs in the same order as
1221    ///      state updates arrived (i.e. transaction order), such sequence is returned.
1222    /// 4. Once there's a sequence of contiguous multiproofs along with the proof targets and state
1223    ///    updates associated with them, a [`SparseTrieUpdate`] is generated and sent to the sparse
1224    ///    trie task.
1225    /// 5. Steps above are repeated until this task receives a
1226    ///    [`MultiProofMessage::FinishedStateUpdates`].
1227    ///    * Once this message is received, on every [`MultiProofMessage::EmptyProof`] and
1228    ///      [`ProofResultMessage`], we check if all proofs have been processed and if there are any
1229    ///      pending proofs in the proof sequencer left to be revealed.
1230    /// 6. While running, consecutive [`MultiProofMessage::PrefetchProofs`] and
1231    ///    [`MultiProofMessage::StateUpdate`] messages are batched to reduce redundant work; if a
1232    ///    different message type arrives mid-batch or a batch cap is reached, it is held as
1233    ///    `pending_msg` and processed on the next loop to preserve ordering.
1234    /// 7. This task exits after all pending proofs are processed.
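    ///
    /// The following sketch is illustrative only (hence `ignore`) and assumes a task constructed
    /// as in the tests at the bottom of this module; it is not a prescribed API flow.
    ///
    /// ```ignore
    /// // Feed the task from the engine side via the control channel.
    /// let tx = task.state_root_message_sender();
    /// tx.send(MultiProofMessage::PrefetchProofs(targets)).unwrap();
    /// tx.send(MultiProofMessage::StateUpdate(source, evm_state)).unwrap();
    /// // Once execution is done the engine sends `MultiProofMessage::FinishedStateUpdates`,
    /// // after which `run` drains the remaining proofs and returns.
    /// task.run();
    /// ```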
1235    #[instrument(
1236        level = "debug",
1237        name = "MultiProofTask::run",
1238        target = "engine::tree::payload_processor::multiproof",
1239        skip_all
1240    )]
1241    pub(crate) fn run(mut self) {
1242        let mut ctx = MultiproofBatchCtx::new(Instant::now());
1243        let mut batch_metrics = MultiproofBatchMetrics::default();
1244
1245        // Main event loop; select_biased! prioritizes proof results over control messages.
1246        // Labeled so inner match arms can `break 'main` once all work is complete.
1247        'main: loop {
1248            trace!(target: "engine::tree::payload_processor::multiproof", "entering main channel receiving loop");
1249
1250            if let Some(msg) = ctx.pending_msg.take() {
1251                if self.process_multiproof_message(msg, &mut ctx, &mut batch_metrics) {
1252                    break 'main;
1253                }
1254                continue;
1255            }
1256
1257            // Use select_biased! to prioritize proof results over new requests.
1258            // This prevents new work from starving completed proofs and keeps workers healthy.
1259            crossbeam_channel::select_biased! {
1260                recv(self.proof_result_rx) -> proof_msg => {
1261                    match proof_msg {
1262                        Ok(proof_result) => {
1263                            batch_metrics.proofs_processed += 1;
1264
1265                            self.metrics
1266                                .proof_calculation_duration_histogram
1267                                .record(proof_result.elapsed);
1268
1269                            self.multiproof_manager.on_calculation_complete();
1270
1271                            // Convert ProofResultMessage to SparseTrieUpdate
1272                            match proof_result.result {
1273                                Ok(proof_result_data) => {
1274                                    trace!(
1275                                        target: "engine::tree::payload_processor::multiproof",
1276                                        sequence = proof_result.sequence_number,
1277                                        total_proofs = batch_metrics.proofs_processed,
1278                                        "Processing calculated proof from worker"
1279                                    );
1280
1281                                    let update = SparseTrieUpdate {
1282                                        state: proof_result.state,
1283                                        multiproof: proof_result_data.into_multiproof(),
1284                                    };
1285
1286                                    if let Some(combined_update) =
1287                                        self.on_proof(proof_result.sequence_number, update)
1288                                    {
1289                                        let _ = self.to_sparse_trie.send(combined_update);
1290                                    }
1291                                }
1292                                Err(error) => {
1293                                    error!(target: "engine::tree::payload_processor::multiproof", ?error, "proof calculation error from worker");
1294                                    return
1295                                }
1296                            }
1297
1298                            if self.is_done(
1299                                batch_metrics.proofs_processed,
1300                                batch_metrics.state_update_proofs_requested,
1301                                batch_metrics.prefetch_proofs_requested,
1302                                ctx.updates_finished(),
1303                            ) {
1304                                debug!(
1305                                    target: "engine::tree::payload_processor::multiproof",
1306                                    "State updates finished and all proofs processed, ending calculation"
1307                                );
1308                                break 'main
1309                            }
1310                        }
1311                        Err(_) => {
1312                            error!(target: "engine::tree::payload_processor::multiproof", "Proof result channel closed unexpectedly");
1313                            return
1314                        }
1315                    }
1316                },
1317                recv(self.rx) -> message => {
1318                    let msg = match message {
1319                        Ok(m) => m,
1320                        Err(_) => {
1321                            error!(target: "engine::tree::payload_processor::multiproof", "State root related message channel closed unexpectedly");
1322                            return
1323                        }
1324                    };
1325
1326                    if self.process_multiproof_message(msg, &mut ctx, &mut batch_metrics) {
1327                        break 'main;
1328                    }
1329                }
1330            }
1331        }
1332
1333        debug!(
1334            target: "engine::tree::payload_processor::multiproof",
1335            total_updates = batch_metrics.state_update_proofs_requested,
1336            total_proofs = batch_metrics.proofs_processed,
1337            total_time = ?ctx.first_update_time.map(|t|t.elapsed()),
1338            time_since_updates_finished = ?ctx.updates_finished_time.map(|t|t.elapsed()),
1339            "All proofs processed, ending calculation"
1340        );
1341
1342        // update total metrics on finish
1343        self.metrics
1344            .state_updates_received_histogram
1345            .record(batch_metrics.state_update_proofs_requested as f64);
1346        self.metrics.proofs_processed_histogram.record(batch_metrics.proofs_processed as f64);
1347        if let Some(total_time) = ctx.first_update_time.map(|t| t.elapsed()) {
1348            self.metrics.multiproof_task_total_duration_histogram.record(total_time);
1349        }
1350
1351        if let Some(updates_finished_time) = ctx.updates_finished_time {
1352            self.metrics
1353                .last_proof_wait_time_histogram
1354                .record(updates_finished_time.elapsed().as_secs_f64());
1355        }
1356    }
1357}
1358
1359/// Context for multiproof message batching loop.
1360///
1361/// Contains processing state that persists across loop iterations.
1362struct MultiproofBatchCtx {
1363    /// Buffers a non-matching message type encountered during batching.
1364    /// Processed first in next iteration to preserve ordering while allowing same-type
1365    /// messages to batch.
1366    pending_msg: Option<MultiProofMessage>,
1367    /// Timestamp when the first state update or prefetch was received.
1368    first_update_time: Option<Instant>,
1369    /// Timestamp taken when the task started, before any state update or prefetch was received.
1370    start: Instant,
1371    /// Timestamp when state updates finished. `Some` indicates all state updates have been
1372    /// received.
1373    updates_finished_time: Option<Instant>,
1374    /// Reusable buffer for accumulating prefetch targets during batching.
1375    accumulated_prefetch_targets: Vec<MultiProofTargets>,
1376    /// Reusable buffer for accumulating state updates during batching.
1377    accumulated_state_updates: Vec<(StateChangeSource, EvmState)>,
1378}
1379
1380impl MultiproofBatchCtx {
1381    /// Creates a new batch context with the given start time.
1382    fn new(start: Instant) -> Self {
1383        Self {
1384            pending_msg: None,
1385            first_update_time: None,
1386            start,
1387            updates_finished_time: None,
1388            accumulated_prefetch_targets: Vec::with_capacity(PREFETCH_MAX_BATCH_MESSAGES),
1389            accumulated_state_updates: Vec::with_capacity(STATE_UPDATE_BATCH_PREALLOC),
1390        }
1391    }
1392
1393    /// Returns `true` if all state updates have been received.
1394    const fn updates_finished(&self) -> bool {
1395        self.updates_finished_time.is_some()
1396    }
1397}
1398
1399/// Counters for tracking proof requests and processing.
1400#[derive(Default)]
1401struct MultiproofBatchMetrics {
1402    /// Number of proofs that have been processed.
1403    proofs_processed: u64,
1404    /// Number of state update proofs requested.
1405    state_update_proofs_requested: u64,
1406    /// Number of prefetch proofs requested.
1407    prefetch_proofs_requested: u64,
1408}
1409
1410/// Returns proof targets containing only the accounts and storage slots that have not been
1411/// fetched yet. An account is omitted entirely if it was already fetched and has no new storage
1412/// slots; storage slots marked as removed are re-included even if they were fetched before.
1413fn get_proof_targets(
1414    state_update: &HashedPostState,
1415    fetched_proof_targets: &MultiProofTargets,
1416    multi_added_removed_keys: &MultiAddedRemovedKeys,
1417) -> MultiProofTargets {
1418    let mut targets = MultiProofTargets::default();
1419
1420    // first collect all new accounts (not previously fetched)
1421    for &hashed_address in state_update.accounts.keys() {
1422        if !fetched_proof_targets.contains_key(&hashed_address) {
1423            targets.insert(hashed_address, HashSet::default());
1424        }
1425    }
1426
1427    // then process storage slots for all accounts in the state update
1428    for (hashed_address, storage) in &state_update.storages {
1429        let fetched = fetched_proof_targets.get(hashed_address);
1430        let storage_added_removed_keys = multi_added_removed_keys.get_storage(hashed_address);
1431        let mut changed_slots = storage
1432            .storage
1433            .keys()
1434            .filter(|slot| {
1435                !fetched.is_some_and(|f| f.contains(*slot)) ||
1436                    storage_added_removed_keys.is_some_and(|k| k.is_removed(slot))
1437            })
1438            .peekable();
1439
1440        // If the storage is wiped, we still need to fetch the account proof.
1441        if storage.wiped && fetched.is_none() {
1442            targets.entry(*hashed_address).or_default();
1443        }
1444
1445        if changed_slots.peek().is_some() {
1446            targets.entry(*hashed_address).or_default().extend(changed_slots);
1447        }
1448    }
1449
1450    targets
1451}
1452
1453/// Dispatches work items as a single unit or in chunks, based on target size and worker
1454/// availability. Returns the number of dispatched units.
1455#[allow(clippy::too_many_arguments)]
1456fn dispatch_with_chunking<T, I>(
1457    items: T,
1458    chunking_len: usize,
1459    chunk_size: Option<usize>,
1460    max_targets_for_chunking: usize,
1461    available_account_workers: usize,
1462    available_storage_workers: usize,
1463    chunker: impl FnOnce(T, usize) -> I,
1464    mut dispatch: impl FnMut(T),
1465) -> usize
1466where
1467    I: IntoIterator<Item = T>,
1468{
1469    let should_chunk = chunking_len > max_targets_for_chunking ||
1470        available_account_workers > 1 ||
1471        available_storage_workers > 1;
1472
1473    if should_chunk &&
1474        let Some(chunk_size) = chunk_size &&
1475        chunking_len > chunk_size
1476    {
1477        let mut num_chunks = 0usize;
1478        for chunk in chunker(items, chunk_size) {
1479            dispatch(chunk);
1480            num_chunks += 1;
1481        }
1482        return num_chunks;
1483    }
1484
1485    dispatch(items);
1486    1
1487}
1488
1489/// Checks whether two state updates can be merged in a batch.
1490///
1491/// Transaction updates with the same transaction ID (`StateChangeSource::Transaction(id)`)
1492/// are safe to merge because they originate from the same logical execution and can be
1493/// coalesced to amortize proof work. Pre/post-block updates merge only with identical payloads.
1494fn can_batch_state_update(
1495    batch_source: StateChangeSource,
1496    batch_update: &EvmState,
1497    next_source: StateChangeSource,
1498    next_update: &EvmState,
1499) -> bool {
1500    if !same_state_change_source(batch_source, next_source) {
1501        return false;
1502    }
1503
1504    match (batch_source, next_source) {
1505        (StateChangeSource::PreBlock(_), StateChangeSource::PreBlock(_)) |
1506        (StateChangeSource::PostBlock(_), StateChangeSource::PostBlock(_)) => {
1507            batch_update == next_update
1508        }
1509        _ => true,
1510    }
1511}
1512
1513/// Checks whether two state change sources refer to the same origin.
1514fn same_state_change_source(lhs: StateChangeSource, rhs: StateChangeSource) -> bool {
1515    match (lhs, rhs) {
1516        (StateChangeSource::Transaction(a), StateChangeSource::Transaction(b)) => a == b,
1517        (StateChangeSource::PreBlock(a), StateChangeSource::PreBlock(b)) => {
1518            mem::discriminant(&a) == mem::discriminant(&b)
1519        }
1520        (StateChangeSource::PostBlock(a), StateChangeSource::PostBlock(b)) => {
1521            mem::discriminant(&a) == mem::discriminant(&b)
1522        }
1523        _ => false,
1524    }
1525}
1526
1527/// Estimates target count from `EvmState` for batching decisions.
1528fn estimate_evm_state_targets(state: &EvmState) -> usize {
1529    state
1530        .values()
1531        .filter(|account| account.is_touched())
1532        .map(|account| {
1533            let changed_slots = account.storage.iter().filter(|(_, v)| v.is_changed()).count();
1534            1 + changed_slots
1535        })
1536        .sum()
1537}
1538
1539#[cfg(test)]
1540mod tests {
1541    use super::*;
1542    use alloy_primitives::map::B256Set;
1543    use reth_provider::{
1544        providers::OverlayStateProviderFactory, test_utils::create_test_provider_factory,
1545        BlockReader, DatabaseProviderFactory, PruneCheckpointReader, StageCheckpointReader,
1546        TrieReader,
1547    };
1548    use reth_trie::MultiProof;
1549    use reth_trie_parallel::proof_task::{ProofTaskCtx, ProofWorkerHandle};
1550    use revm_primitives::{B256, U256};
1551    use std::sync::OnceLock;
1552    use tokio::runtime::{Handle, Runtime};
1553
1554    /// Get a handle to the test runtime, creating it if necessary
1555    fn get_test_runtime_handle() -> Handle {
1556        static TEST_RT: OnceLock<Runtime> = OnceLock::new();
1557        TEST_RT
1558            .get_or_init(|| {
1559                tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()
1560            })
1561            .handle()
1562            .clone()
1563    }
1564
1565    fn create_test_state_root_task<F>(factory: F) -> MultiProofTask
1566    where
1567        F: DatabaseProviderFactory<
1568                Provider: BlockReader + TrieReader + StageCheckpointReader + PruneCheckpointReader,
1569            > + Clone
1570            + Send
1571            + 'static,
1572    {
1573        let rt_handle = get_test_runtime_handle();
1574        let overlay_factory = OverlayStateProviderFactory::new(factory);
1575        let task_ctx = ProofTaskCtx::new(overlay_factory);
1576        let proof_handle = ProofWorkerHandle::new(rt_handle, task_ctx, 1, 1);
1577        let (to_sparse_trie, _receiver) = std::sync::mpsc::channel();
1578
1579        MultiProofTask::new(proof_handle, to_sparse_trie, Some(1))
1580    }
1581
1582    #[test]
1583    fn test_add_proof_in_sequence() {
1584        let mut sequencer = ProofSequencer::default();
1585        let proof1 = MultiProof::default();
1586        let proof2 = MultiProof::default();
1587        sequencer.next_sequence = 2;
1588
1589        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1590        assert_eq!(ready.len(), 1);
1591        assert!(!sequencer.has_pending());
1592
1593        let ready = sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proof2).unwrap());
1594        assert_eq!(ready.len(), 1);
1595        assert!(!sequencer.has_pending());
1596    }
1597
1598    #[test]
1599    fn test_add_proof_out_of_order() {
1600        let mut sequencer = ProofSequencer::default();
1601        let proof1 = MultiProof::default();
1602        let proof2 = MultiProof::default();
1603        let proof3 = MultiProof::default();
1604        sequencer.next_sequence = 3;
1605
1606        let ready = sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proof3).unwrap());
1607        assert_eq!(ready.len(), 0);
1608        assert!(sequencer.has_pending());
1609
1610        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1611        assert_eq!(ready.len(), 1);
1612        assert!(sequencer.has_pending());
1613
1614        let ready = sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proof2).unwrap());
1615        assert_eq!(ready.len(), 2);
1616        assert!(!sequencer.has_pending());
1617    }
1618
1619    #[test]
1620    fn test_add_proof_with_gaps() {
1621        let mut sequencer = ProofSequencer::default();
1622        let proof1 = MultiProof::default();
1623        let proof3 = MultiProof::default();
1624        sequencer.next_sequence = 3;
1625
1626        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1627        assert_eq!(ready.len(), 1);
1628
1629        let ready = sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proof3).unwrap());
1630        assert_eq!(ready.len(), 0);
1631        assert!(sequencer.has_pending());
1632    }
1633
1634    #[test]
1635    fn test_add_proof_duplicate_sequence() {
1636        let mut sequencer = ProofSequencer::default();
1637        let proof1 = MultiProof::default();
1638        let proof2 = MultiProof::default();
1639
1640        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof1).unwrap());
1641        assert_eq!(ready.len(), 1);
1642
1643        let ready = sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proof2).unwrap());
1644        assert_eq!(ready.len(), 0);
1645        assert!(!sequencer.has_pending());
1646    }
1647
1648    #[test]
1649    fn test_add_proof_batch_processing() {
1650        let mut sequencer = ProofSequencer::default();
1651        let proofs: Vec<_> = (0..5).map(|_| MultiProof::default()).collect();
1652        sequencer.next_sequence = 5;
1653
1654        sequencer.add_proof(4, SparseTrieUpdate::from_multiproof(proofs[4].clone()).unwrap());
1655        sequencer.add_proof(2, SparseTrieUpdate::from_multiproof(proofs[2].clone()).unwrap());
1656        sequencer.add_proof(1, SparseTrieUpdate::from_multiproof(proofs[1].clone()).unwrap());
1657        sequencer.add_proof(3, SparseTrieUpdate::from_multiproof(proofs[3].clone()).unwrap());
1658
1659        let ready =
1660            sequencer.add_proof(0, SparseTrieUpdate::from_multiproof(proofs[0].clone()).unwrap());
1661        assert_eq!(ready.len(), 5);
1662        assert!(!sequencer.has_pending());
1663    }
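
    /// Illustrative sketch added for documentation purposes (not part of the original suite):
    /// an empty `SparseTrieUpdate` stays empty when extended with another empty update.
    #[test]
    fn test_sparse_trie_update_extend_empty_sketch() {
        let mut update = SparseTrieUpdate::default();
        assert!(update.is_empty());

        // Extending with an update built from an empty multiproof keeps it empty.
        update.extend(SparseTrieUpdate::from_multiproof(MultiProof::default()).unwrap());
        assert!(update.is_empty());
    }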
1664
1665    fn create_get_proof_targets_state() -> HashedPostState {
1666        let mut state = HashedPostState::default();
1667
1668        let addr1 = B256::random();
1669        let addr2 = B256::random();
1670        state.accounts.insert(addr1, Some(Default::default()));
1671        state.accounts.insert(addr2, Some(Default::default()));
1672
1673        let mut storage = HashedStorage::default();
1674        let slot1 = B256::random();
1675        let slot2 = B256::random();
1676        storage.storage.insert(slot1, U256::ZERO);
1677        storage.storage.insert(slot2, U256::from(1));
1678        state.storages.insert(addr1, storage);
1679
1680        state
1681    }
1682
1683    #[test]
1684    fn test_get_proof_targets_new_account_targets() {
1685        let state = create_get_proof_targets_state();
1686        let fetched = MultiProofTargets::default();
1687
1688        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1689
1690        // should return all accounts as targets since nothing was fetched before
1691        assert_eq!(targets.len(), state.accounts.len());
1692        for addr in state.accounts.keys() {
1693            assert!(targets.contains_key(addr));
1694        }
1695    }
1696
1697    #[test]
1698    fn test_get_proof_targets_new_storage_targets() {
1699        let state = create_get_proof_targets_state();
1700        let fetched = MultiProofTargets::default();
1701
1702        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1703
1704        // verify storage slots are included for accounts with storage
1705        for (addr, storage) in &state.storages {
1706            assert!(targets.contains_key(addr));
1707            let target_slots = &targets[addr];
1708            assert_eq!(target_slots.len(), storage.storage.len());
1709            for slot in storage.storage.keys() {
1710                assert!(target_slots.contains(slot));
1711            }
1712        }
1713    }
1714
1715    #[test]
1716    fn test_get_proof_targets_filter_already_fetched_accounts() {
1717        let state = create_get_proof_targets_state();
1718        let mut fetched = MultiProofTargets::default();
1719
1720        // select an account that has no storage updates
1721        let fetched_addr = state
1722            .accounts
1723            .keys()
1724            .find(|&&addr| !state.storages.contains_key(&addr))
1725            .expect("Should have an account without storage");
1726
1727        // mark the account as already fetched
1728        fetched.insert(*fetched_addr, HashSet::default());
1729
1730        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1731
1732        // should not include the already fetched account since it has no storage updates
1733        assert!(!targets.contains_key(fetched_addr));
1734        // other accounts should still be included
1735        assert_eq!(targets.len(), state.accounts.len() - 1);
1736    }
1737
1738    #[test]
1739    fn test_get_proof_targets_filter_already_fetched_storage() {
1740        let state = create_get_proof_targets_state();
1741        let mut fetched = MultiProofTargets::default();
1742
1743        // mark one storage slot as already fetched
1744        let (addr, storage) = state.storages.iter().next().unwrap();
1745        let mut fetched_slots = HashSet::default();
1746        let fetched_slot = *storage.storage.keys().next().unwrap();
1747        fetched_slots.insert(fetched_slot);
1748        fetched.insert(*addr, fetched_slots);
1749
1750        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1751
1752        // should not include the already fetched storage slot
1753        let target_slots = &targets[addr];
1754        assert!(!target_slots.contains(&fetched_slot));
1755        assert_eq!(target_slots.len(), storage.storage.len() - 1);
1756    }
1757
1758    #[test]
1759    fn test_get_proof_targets_empty_state() {
1760        let state = HashedPostState::default();
1761        let fetched = MultiProofTargets::default();
1762
1763        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1764
1765        assert!(targets.is_empty());
1766    }
1767
1768    #[test]
1769    fn test_get_proof_targets_mixed_fetched_state() {
1770        let mut state = HashedPostState::default();
1771        let mut fetched = MultiProofTargets::default();
1772
1773        let addr1 = B256::random();
1774        let addr2 = B256::random();
1775        let slot1 = B256::random();
1776        let slot2 = B256::random();
1777
1778        state.accounts.insert(addr1, Some(Default::default()));
1779        state.accounts.insert(addr2, Some(Default::default()));
1780
1781        let mut storage = HashedStorage::default();
1782        storage.storage.insert(slot1, U256::ZERO);
1783        storage.storage.insert(slot2, U256::from(1));
1784        state.storages.insert(addr1, storage);
1785
1786        let mut fetched_slots = HashSet::default();
1787        fetched_slots.insert(slot1);
1788        fetched.insert(addr1, fetched_slots);
1789
1790        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1791
1792        assert!(targets.contains_key(&addr2));
1793        assert!(!targets[&addr1].contains(&slot1));
1794        assert!(targets[&addr1].contains(&slot2));
1795    }
1796
1797    #[test]
1798    fn test_get_proof_targets_unmodified_account_with_storage() {
1799        let mut state = HashedPostState::default();
1800        let fetched = MultiProofTargets::default();
1801
1802        let addr = B256::random();
1803        let slot1 = B256::random();
1804        let slot2 = B256::random();
1805
1806        // don't add the account to state.accounts (simulating unmodified account)
1807        // but add storage updates for this account
1808        let mut storage = HashedStorage::default();
1809        storage.storage.insert(slot1, U256::from(1));
1810        storage.storage.insert(slot2, U256::from(2));
1811        state.storages.insert(addr, storage);
1812
1813        assert!(!state.accounts.contains_key(&addr));
1814        assert!(!fetched.contains_key(&addr));
1815
1816        let targets = get_proof_targets(&state, &fetched, &MultiAddedRemovedKeys::new());
1817
1818        // verify that we still get the storage slots for the unmodified account
1819        assert!(targets.contains_key(&addr));
1820
1821        let target_slots = &targets[&addr];
1822        assert_eq!(target_slots.len(), 2);
1823        assert!(target_slots.contains(&slot1));
1824        assert!(target_slots.contains(&slot2));
1825    }
1826
1827    #[test]
1828    fn test_get_prefetch_proof_targets_no_duplicates() {
1829        let test_provider_factory = create_test_provider_factory();
1830        let mut test_state_root_task = create_test_state_root_task(test_provider_factory);
1831
1832        // populate some targets
1833        let mut targets = MultiProofTargets::default();
1834        let addr1 = B256::random();
1835        let addr2 = B256::random();
1836        let slot1 = B256::random();
1837        let slot2 = B256::random();
1838        targets.insert(addr1, std::iter::once(slot1).collect());
1839        targets.insert(addr2, std::iter::once(slot2).collect());
1840
1841        let prefetch_proof_targets =
1842            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1843
1844        // check that the prefetch proof targets are the same because there are no fetched proof
1845        // targets yet
1846        assert_eq!(prefetch_proof_targets, targets);
1847
1848        // add a different addr and slot to fetched proof targets
1849        let addr3 = B256::random();
1850        let slot3 = B256::random();
1851        test_state_root_task.fetched_proof_targets.insert(addr3, std::iter::once(slot3).collect());
1852
1853        let prefetch_proof_targets =
1854            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1855
1856        // check that the prefetch proof targets are the same because the fetched proof targets
1857        // don't overlap with the prefetch targets
1858        assert_eq!(prefetch_proof_targets, targets);
1859    }
1860
1861    #[test]
1862    fn test_get_prefetch_proof_targets_remove_subset() {
1863        let test_provider_factory = create_test_provider_factory();
1864        let mut test_state_root_task = create_test_state_root_task(test_provider_factory);
1865
1866        // populate some targets
1867        let mut targets = MultiProofTargets::default();
1868        let addr1 = B256::random();
1869        let addr2 = B256::random();
1870        let slot1 = B256::random();
1871        let slot2 = B256::random();
1872        targets.insert(addr1, std::iter::once(slot1).collect());
1873        targets.insert(addr2, std::iter::once(slot2).collect());
1874
1875        // add a subset of the first target to fetched proof targets
1876        test_state_root_task.fetched_proof_targets.insert(addr1, std::iter::once(slot1).collect());
1877
1878        let prefetch_proof_targets =
1879            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1880
1881        // check that the prefetch proof targets do not include the subset
1882        assert_eq!(prefetch_proof_targets.len(), 1);
1883        assert!(!prefetch_proof_targets.contains_key(&addr1));
1884        assert!(prefetch_proof_targets.contains_key(&addr2));
1885
1886        // now add one more slot to the prefetch targets
1887        let slot3 = B256::random();
1888        targets.get_mut(&addr1).unwrap().insert(slot3);
1889
1890        let prefetch_proof_targets =
1891            test_state_root_task.get_prefetch_proof_targets(targets.clone());
1892
1893        // check that the prefetch proof targets do not include the subset
1894        // but include the new slot
1895        assert_eq!(prefetch_proof_targets.len(), 2);
1896        assert!(prefetch_proof_targets.contains_key(&addr1));
1897        assert_eq!(
1898            *prefetch_proof_targets.get(&addr1).unwrap(),
1899            std::iter::once(slot3).collect::<B256Set>()
1900        );
1901        assert!(prefetch_proof_targets.contains_key(&addr2));
1902        assert_eq!(
1903            *prefetch_proof_targets.get(&addr2).unwrap(),
1904            std::iter::once(slot2).collect::<B256Set>()
1905        );
1906    }
1907
1908    #[test]
1909    fn test_get_proof_targets_with_removed_storage_keys() {
1910        let mut state = HashedPostState::default();
1911        let mut fetched = MultiProofTargets::default();
1912        let mut multi_added_removed_keys = MultiAddedRemovedKeys::new();
1913
1914        let addr = B256::random();
1915        let slot1 = B256::random();
1916        let slot2 = B256::random();
1917
1918        // add account to state
1919        state.accounts.insert(addr, Some(Default::default()));
1920
1921        // add storage updates
1922        let mut storage = HashedStorage::default();
1923        storage.storage.insert(slot1, U256::from(100));
1924        storage.storage.insert(slot2, U256::from(200));
1925        state.storages.insert(addr, storage);
1926
1927        // mark slot1 as already fetched
1928        let mut fetched_slots = HashSet::default();
1929        fetched_slots.insert(slot1);
1930        fetched.insert(addr, fetched_slots);
1931
1932        // update multi_added_removed_keys to mark slot1 as removed
1933        let mut removed_state = HashedPostState::default();
1934        let mut removed_storage = HashedStorage::default();
1935        removed_storage.storage.insert(slot1, U256::ZERO); // U256::ZERO marks as removed
1936        removed_state.storages.insert(addr, removed_storage);
1937        multi_added_removed_keys.update_with_state(&removed_state);
1938
1939        let targets = get_proof_targets(&state, &fetched, &multi_added_removed_keys);
1940
1941        // slot1 should be included despite being fetched, because it's marked as removed
1942        assert!(targets.contains_key(&addr));
1943        let target_slots = &targets[&addr];
1944        assert_eq!(target_slots.len(), 2);
1945        assert!(target_slots.contains(&slot1)); // included because it's removed
1946        assert!(target_slots.contains(&slot2)); // included because it's not fetched
1947    }
1948
1949    #[test]
1950    fn test_get_proof_targets_with_wiped_storage() {
1951        let mut state = HashedPostState::default();
1952        let fetched = MultiProofTargets::default();
1953        let multi_added_removed_keys = MultiAddedRemovedKeys::new();
1954
1955        let addr = B256::random();
1956        let slot1 = B256::random();
1957
1958        // add account to state
1959        state.accounts.insert(addr, Some(Default::default()));
1960
1961        // add wiped storage
1962        let mut storage = HashedStorage::new(true);
1963        storage.storage.insert(slot1, U256::from(100));
1964        state.storages.insert(addr, storage);
1965
1966        let targets = get_proof_targets(&state, &fetched, &multi_added_removed_keys);
1967
1968        // account should be included because storage is wiped and account wasn't fetched
1969        assert!(targets.contains_key(&addr));
1970        let target_slots = &targets[&addr];
1971        assert_eq!(target_slots.len(), 1);
1972        assert!(target_slots.contains(&slot1));
1973    }
1974
1975    #[test]
1976    fn test_get_proof_targets_removed_keys_not_in_state_update() {
1977        let mut state = HashedPostState::default();
1978        let mut fetched = MultiProofTargets::default();
1979        let mut multi_added_removed_keys = MultiAddedRemovedKeys::new();
1980
1981        let addr = B256::random();
1982        let slot1 = B256::random();
1983        let slot2 = B256::random();
1984        let slot3 = B256::random();
1985
1986        // add account to state
1987        state.accounts.insert(addr, Some(Default::default()));
1988
1989        // add storage updates for slot1 and slot2 only
1990        let mut storage = HashedStorage::default();
1991        storage.storage.insert(slot1, U256::from(100));
1992        storage.storage.insert(slot2, U256::from(200));
1993        state.storages.insert(addr, storage);
1994
1995        // mark all slots as already fetched
1996        let mut fetched_slots = HashSet::default();
1997        fetched_slots.insert(slot1);
1998        fetched_slots.insert(slot2);
1999        fetched_slots.insert(slot3); // slot3 is fetched but not in state update
2000        fetched.insert(addr, fetched_slots);
2001
2002        // mark slot3 as removed (even though it's not in the state update)
2003        let mut removed_state = HashedPostState::default();
2004        let mut removed_storage = HashedStorage::default();
2005        removed_storage.storage.insert(slot3, U256::ZERO);
2006        removed_state.storages.insert(addr, removed_storage);
2007        multi_added_removed_keys.update_with_state(&removed_state);
2008
2009        let targets = get_proof_targets(&state, &fetched, &multi_added_removed_keys);
2010
2011        // only slots present in the state update are eligible, and those were already fetched
2012        assert!(!targets.contains_key(&addr));
2013    }
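
    /// Illustrative sketch (not part of the original suite): exercises `dispatch_with_chunking`
    /// with plain `Vec<u32>` items standing in for proof targets; the numeric values are
    /// arbitrary and chosen only to make the chunk boundaries visible.
    #[test]
    fn test_dispatch_with_chunking_sketch() {
        let chunker = |items: Vec<u32>, size: usize| {
            items.chunks(size).map(|chunk| chunk.to_vec()).collect::<Vec<_>>()
        };

        // 10 targets exceed both the chunking threshold (5) and the chunk size (4), so the
        // items are dispatched as ceil(10 / 4) = 3 chunks.
        let mut dispatched: Vec<Vec<u32>> = Vec::new();
        let num_dispatched = dispatch_with_chunking(
            (0..10u32).collect::<Vec<_>>(),
            10,      // chunking_len
            Some(4), // chunk_size
            5,       // max_targets_for_chunking
            1,       // available_account_workers
            1,       // available_storage_workers
            chunker,
            |chunk| dispatched.push(chunk),
        );
        assert_eq!(num_dispatched, 3);
        assert_eq!(dispatched.concat(), (0..10u32).collect::<Vec<_>>());

        // 3 targets stay below the threshold and only one worker of each kind is available,
        // so everything goes out as a single unit.
        let mut dispatched: Vec<Vec<u32>> = Vec::new();
        let num_dispatched = dispatch_with_chunking(
            (0..3u32).collect::<Vec<_>>(),
            3,
            Some(4),
            5,
            1,
            1,
            chunker,
            |chunk| dispatched.push(chunk),
        );
        assert_eq!(num_dispatched, 1);
        assert_eq!(dispatched.len(), 1);
    }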
2014
2015    /// Verifies that consecutive prefetch proof messages are batched together.
2016    #[test]
2017    fn test_prefetch_proofs_batching() {
2018        let test_provider_factory = create_test_provider_factory();
2019        let mut task = create_test_state_root_task(test_provider_factory);
2020
2021        // send multiple messages
2022        let addr1 = B256::random();
2023        let addr2 = B256::random();
2024        let addr3 = B256::random();
2025
2026        let mut targets1 = MultiProofTargets::default();
2027        targets1.insert(addr1, HashSet::default());
2028
2029        let mut targets2 = MultiProofTargets::default();
2030        targets2.insert(addr2, HashSet::default());
2031
2032        let mut targets3 = MultiProofTargets::default();
2033        targets3.insert(addr3, HashSet::default());
2034
2035        let tx = task.state_root_message_sender();
2036        tx.send(MultiProofMessage::PrefetchProofs(targets1)).unwrap();
2037        tx.send(MultiProofMessage::PrefetchProofs(targets2)).unwrap();
2038        tx.send(MultiProofMessage::PrefetchProofs(targets3)).unwrap();
2039
2040        let proofs_requested =
2041            if let Ok(MultiProofMessage::PrefetchProofs(targets)) = task.rx.recv() {
2042                // simulate the batching logic
2043                let mut merged_targets = targets;
2044                let mut num_batched = 1;
2045                while let Ok(MultiProofMessage::PrefetchProofs(next_targets)) = task.rx.try_recv() {
2046                    merged_targets.extend(next_targets);
2047                    num_batched += 1;
2048                }
2049
2050                assert_eq!(num_batched, 3);
2051                assert_eq!(merged_targets.len(), 3);
2052                assert!(merged_targets.contains_key(&addr1));
2053                assert!(merged_targets.contains_key(&addr2));
2054                assert!(merged_targets.contains_key(&addr3));
2055
2056                task.on_prefetch_proof(merged_targets)
2057            } else {
2058                panic!("Expected PrefetchProofs message");
2059            };
2060
2061        assert_eq!(proofs_requested, 1);
2062    }
2063
2064    /// Verifies that consecutive state update messages from the same source are batched together.
2065    #[test]
2066    fn test_state_update_batching() {
2067        use alloy_evm::block::StateChangeSource;
2068        use revm_state::Account;
2069
2070        let test_provider_factory = create_test_provider_factory();
2071        let mut task = create_test_state_root_task(test_provider_factory);
2072
2073        // create multiple state updates
2074        let addr1 = alloy_primitives::Address::random();
2075        let addr2 = alloy_primitives::Address::random();
2076
2077        let mut update1 = EvmState::default();
2078        update1.insert(
2079            addr1,
2080            Account {
2081                info: revm_state::AccountInfo {
2082                    balance: U256::from(100),
2083                    nonce: 1,
2084                    code_hash: Default::default(),
2085                    code: Default::default(),
2086                },
2087                transaction_id: Default::default(),
2088                storage: Default::default(),
2089                status: revm_state::AccountStatus::Touched,
2090            },
2091        );
2092
2093        let mut update2 = EvmState::default();
2094        update2.insert(
2095            addr2,
2096            Account {
2097                info: revm_state::AccountInfo {
2098                    balance: U256::from(200),
2099                    nonce: 2,
2100                    code_hash: Default::default(),
2101                    code: Default::default(),
2102                },
2103                transaction_id: Default::default(),
2104                storage: Default::default(),
2105                status: revm_state::AccountStatus::Touched,
2106            },
2107        );
2108
2109        let source = StateChangeSource::Transaction(0);
2110
2111        let tx = task.state_root_message_sender();
2112        tx.send(MultiProofMessage::StateUpdate(source, update1.clone())).unwrap();
2113        tx.send(MultiProofMessage::StateUpdate(source, update2.clone())).unwrap();
2114
2115        let proofs_requested =
2116            if let Ok(MultiProofMessage::StateUpdate(_src, update)) = task.rx.recv() {
2117                let mut merged_update = update;
2118                let mut num_batched = 1;
2119
2120                while let Ok(MultiProofMessage::StateUpdate(_next_source, next_update)) =
2121                    task.rx.try_recv()
2122                {
2123                    merged_update.extend(next_update);
2124                    num_batched += 1;
2125                }
2126
2127                assert_eq!(num_batched, 2);
2128                assert_eq!(merged_update.len(), 2);
2129                assert!(merged_update.contains_key(&addr1));
2130                assert!(merged_update.contains_key(&addr2));
2131
2132                task.on_state_update(source, merged_update)
2133            } else {
2134                panic!("Expected StateUpdate message");
2135            };
2136        assert_eq!(proofs_requested, 1);
2137    }
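
    /// Illustrative sketch (not part of the original suite): checks the batching predicates
    /// directly. Transaction updates with the same id may merge regardless of payload, different
    /// ids never merge, and pre-block updates merge only when their payloads are identical.
    #[test]
    fn test_can_batch_state_update_sketch() {
        use alloy_evm::block::{StateChangePreBlockSource, StateChangeSource};
        use revm_state::Account;

        let make_update = |addr: alloy_primitives::Address, balance: u64| {
            let mut state = EvmState::default();
            state.insert(
                addr,
                Account {
                    info: revm_state::AccountInfo {
                        balance: U256::from(balance),
                        nonce: 1,
                        code_hash: Default::default(),
                        code: Default::default(),
                    },
                    transaction_id: Default::default(),
                    storage: Default::default(),
                    status: revm_state::AccountStatus::Touched,
                },
            );
            state
        };

        let update_a = make_update(alloy_primitives::Address::random(), 1);
        let update_b = make_update(alloy_primitives::Address::random(), 2);

        // Same transaction id: mergeable even though the payloads differ.
        let tx0 = StateChangeSource::Transaction(0);
        assert!(same_state_change_source(tx0, StateChangeSource::Transaction(0)));
        assert!(can_batch_state_update(tx0, &update_a, tx0, &update_b));

        // Different transaction ids: never mergeable.
        let tx1 = StateChangeSource::Transaction(1);
        assert!(!same_state_change_source(tx0, tx1));
        assert!(!can_batch_state_update(tx0, &update_a, tx1, &update_a));

        // Pre-block updates: mergeable only with an identical payload.
        let pre = StateChangeSource::PreBlock(StateChangePreBlockSource::BeaconRootContract);
        assert!(can_batch_state_update(pre, &update_a, pre, &update_a));
        assert!(!can_batch_state_update(pre, &update_a, pre, &update_b));
    }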
2138
2139    /// Verifies that state updates from different sources are not batched together.
2140    #[test]
2141    fn test_state_update_batching_separates_sources() {
2142        use alloy_evm::block::StateChangeSource;
2143        use revm_state::Account;
2144
2145        let test_provider_factory = create_test_provider_factory();
2146        let task = create_test_state_root_task(test_provider_factory);
2147
2148        let addr_a1 = alloy_primitives::Address::random();
2149        let addr_b1 = alloy_primitives::Address::random();
2150        let addr_a2 = alloy_primitives::Address::random();
2151
2152        let create_state_update = |addr: alloy_primitives::Address, balance: u64| {
2153            let mut state = EvmState::default();
2154            state.insert(
2155                addr,
2156                Account {
2157                    info: revm_state::AccountInfo {
2158                        balance: U256::from(balance),
2159                        nonce: 1,
2160                        code_hash: Default::default(),
2161                        code: Default::default(),
2162                    },
2163                    transaction_id: Default::default(),
2164                    storage: Default::default(),
2165                    status: revm_state::AccountStatus::Touched,
2166                },
2167            );
2168            state
2169        };
2170
2171        let source_a = StateChangeSource::Transaction(1);
2172        let source_b = StateChangeSource::Transaction(2);
2173
2174        // Queue: A1 (immediate dispatch), B1 (batched), A2 (should become pending)
2175        let tx = task.state_root_message_sender();
2176        tx.send(MultiProofMessage::StateUpdate(source_a, create_state_update(addr_a1, 100)))
2177            .unwrap();
2178        tx.send(MultiProofMessage::StateUpdate(source_b, create_state_update(addr_b1, 200)))
2179            .unwrap();
2180        tx.send(MultiProofMessage::StateUpdate(source_a, create_state_update(addr_a2, 300)))
2181            .unwrap();
2182
2183        let mut pending_msg: Option<MultiProofMessage> = None;
2184
2185        if let Ok(MultiProofMessage::StateUpdate(first_source, _)) = task.rx.recv() {
2186            assert!(same_state_change_source(first_source, source_a));
2187
2188            // Simulate batching loop for remaining messages
2189            let mut accumulated_updates: Vec<(StateChangeSource, EvmState)> = Vec::new();
2190            let mut accumulated_targets = 0usize;
2191
2192            loop {
2193                if accumulated_targets >= STATE_UPDATE_MAX_BATCH_TARGETS {
2194                    break;
2195                }
2196                match task.rx.try_recv() {
2197                    Ok(MultiProofMessage::StateUpdate(next_source, next_update)) => {
2198                        if let Some((batch_source, batch_update)) = accumulated_updates.first() &&
2199                            !can_batch_state_update(
2200                                *batch_source,
2201                                batch_update,
2202                                next_source,
2203                                &next_update,
2204                            )
2205                        {
2206                            pending_msg =
2207                                Some(MultiProofMessage::StateUpdate(next_source, next_update));
2208                            break;
2209                        }
2210
2211                        let next_estimate = estimate_evm_state_targets(&next_update);
2212                        if next_estimate > STATE_UPDATE_MAX_BATCH_TARGETS {
2213                            pending_msg =
2214                                Some(MultiProofMessage::StateUpdate(next_source, next_update));
2215                            break;
2216                        }
2217                        if accumulated_targets + next_estimate > STATE_UPDATE_MAX_BATCH_TARGETS &&
2218                            !accumulated_updates.is_empty()
2219                        {
2220                            pending_msg =
2221                                Some(MultiProofMessage::StateUpdate(next_source, next_update));
2222                            break;
2223                        }
2224                        accumulated_targets += next_estimate;
2225                        accumulated_updates.push((next_source, next_update));
2226                    }
2227                    Ok(other_msg) => {
2228                        pending_msg = Some(other_msg);
2229                        break;
2230                    }
2231                    Err(_) => break,
2232                }
2233            }
2234
2235            assert_eq!(accumulated_updates.len(), 1, "Should only batch matching sources");
2236            let batch_source = accumulated_updates[0].0;
2238
2240            let mut merged_update = accumulated_updates.remove(0).1;
2241            for (_, next_update) in accumulated_updates {
2242                merged_update.extend(next_update);
2243            }
2244
2245            assert!(
2246                same_state_change_source(batch_source, source_b),
2247                "Batch should use matching source"
2248            );
2249            assert!(merged_update.contains_key(&addr_b1));
2250            assert!(!merged_update.contains_key(&addr_a1));
2251            assert!(!merged_update.contains_key(&addr_a2));
2252        } else {
2253            panic!("Expected first StateUpdate");
2254        }
2255
2256        match pending_msg {
2257            Some(MultiProofMessage::StateUpdate(pending_source, pending_update)) => {
2258                assert!(same_state_change_source(pending_source, source_a));
2259                assert!(pending_update.contains_key(&addr_a2));
2260            }
2261            other => panic!("Expected pending StateUpdate with source_a, got {:?}", other),
2262        }
2263    }
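
    /// Illustrative sketch (not part of the original suite): `estimate_evm_state_targets` counts
    /// one target per touched account plus one per changed storage slot. Storage changes are
    /// omitted here to keep the construction simple, so only the per-account count is visible.
    #[test]
    fn test_estimate_evm_state_targets_sketch() {
        use revm_state::Account;

        // An empty state has nothing to prove.
        assert_eq!(estimate_evm_state_targets(&EvmState::default()), 0);

        let make_touched_account = || Account {
            info: revm_state::AccountInfo {
                balance: U256::from(1),
                nonce: 1,
                code_hash: Default::default(),
                code: Default::default(),
            },
            transaction_id: Default::default(),
            storage: Default::default(),
            status: revm_state::AccountStatus::Touched,
        };

        // Two touched accounts without storage changes contribute one target each.
        let mut state = EvmState::default();
        state.insert(alloy_primitives::Address::random(), make_touched_account());
        state.insert(alloy_primitives::Address::random(), make_touched_account());
        assert_eq!(estimate_evm_state_targets(&state), 2);
    }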
2264
2265    /// Verifies that pre-block updates only batch when their payloads are identical.
2266    #[test]
2267    fn test_pre_block_updates_require_payload_match_to_batch() {
2268        use alloy_evm::block::{StateChangePreBlockSource, StateChangeSource};
2269        use revm_state::Account;
2270
2271        let test_provider_factory = create_test_provider_factory();
2272        let task = create_test_state_root_task(test_provider_factory);
2273
2274        let addr1 = alloy_primitives::Address::random();
2275        let addr2 = alloy_primitives::Address::random();
2276        let addr3 = alloy_primitives::Address::random();
2277
2278        let create_state_update = |addr: alloy_primitives::Address, balance: u64| {
2279            let mut state = EvmState::default();
2280            state.insert(
2281                addr,
2282                Account {
2283                    info: revm_state::AccountInfo {
2284                        balance: U256::from(balance),
2285                        nonce: 1,
2286                        code_hash: Default::default(),
2287                        code: Default::default(),
2288                    },
2289                    transaction_id: Default::default(),
2290                    storage: Default::default(),
2291                    status: revm_state::AccountStatus::Touched,
2292                },
2293            );
2294            state
2295        };
2296
2297        let source = StateChangeSource::PreBlock(StateChangePreBlockSource::BeaconRootContract);
2298
2299        // Queue: first update dispatched immediately, next two should not merge
2300        let tx = task.state_root_message_sender();
2301        tx.send(MultiProofMessage::StateUpdate(source, create_state_update(addr1, 100))).unwrap();
2302        tx.send(MultiProofMessage::StateUpdate(source, create_state_update(addr2, 200))).unwrap();
2303        tx.send(MultiProofMessage::StateUpdate(source, create_state_update(addr3, 300))).unwrap();
2304
2305        let mut pending_msg: Option<MultiProofMessage> = None;
2306
2307        if let Ok(MultiProofMessage::StateUpdate(first_source, first_update)) = task.rx.recv() {
2308            assert!(same_state_change_source(first_source, source));
2309            assert!(first_update.contains_key(&addr1));
2310
2311            let mut accumulated_updates: Vec<(StateChangeSource, EvmState)> = Vec::new();
2312            let mut accumulated_targets = 0usize;
2313
2314            loop {
2315                if accumulated_targets >= STATE_UPDATE_MAX_BATCH_TARGETS {
2316                    break;
2317                }
2318                match task.rx.try_recv() {
2319                    Ok(MultiProofMessage::StateUpdate(next_source, next_update)) => {
2320                        if let Some((batch_source, batch_update)) = accumulated_updates.first() &&
2321                            !can_batch_state_update(
2322                                *batch_source,
2323                                batch_update,
2324                                next_source,
2325                                &next_update,
2326                            )
2327                        {
2328                            pending_msg =
2329                                Some(MultiProofMessage::StateUpdate(next_source, next_update));
2330                            break;
2331                        }
2332
2333                        let next_estimate = estimate_evm_state_targets(&next_update);
2334                        if next_estimate > STATE_UPDATE_MAX_BATCH_TARGETS {
2335                            pending_msg =
2336                                Some(MultiProofMessage::StateUpdate(next_source, next_update));
2337                            break;
2338                        }
2339                        if accumulated_targets + next_estimate > STATE_UPDATE_MAX_BATCH_TARGETS &&
2340                            !accumulated_updates.is_empty()
2341                        {
2342                            pending_msg =
2343                                Some(MultiProofMessage::StateUpdate(next_source, next_update));
2344                            break;
2345                        }
2346                        accumulated_targets += next_estimate;
2347                        accumulated_updates.push((next_source, next_update));
2348                    }
2349                    Ok(other_msg) => {
2350                        pending_msg = Some(other_msg);
2351                        break;
2352                    }
2353                    Err(_) => break,
2354                }
2355            }
2356
2357            assert_eq!(
2358                accumulated_updates.len(),
2359                1,
2360                "Second pre-block update should not merge with a different payload"
2361            );
2362            let (batched_source, batched_update) = accumulated_updates.remove(0);
2363            assert!(same_state_change_source(batched_source, source));
2364            assert!(batched_update.contains_key(&addr2));
2365            assert!(!batched_update.contains_key(&addr3));
2366
2367            match pending_msg {
2368                Some(MultiProofMessage::StateUpdate(_, pending_update)) => {
2369                    assert!(pending_update.contains_key(&addr3));
2370                }
2371                other => panic!("Expected pending third pre-block update, got {:?}", other),
2372            }
2373        } else {
2374            panic!("Expected first StateUpdate");
2375        }
2376    }
2377
2378    /// Verifies that different message types arriving mid-batch are not lost and preserve order.
2379    #[test]
2380    fn test_batching_preserves_ordering_with_different_message_type() {
2381        use alloy_evm::block::StateChangeSource;
2382        use revm_state::Account;
2383
2384        let test_provider_factory = create_test_provider_factory();
2385        let task = create_test_state_root_task(test_provider_factory);
2386
2387        let addr1 = B256::random();
2388        let addr2 = B256::random();
2389        let addr3 = B256::random();
2390        let state_addr1 = alloy_primitives::Address::random();
2391        let state_addr2 = alloy_primitives::Address::random();
2392
2393        // Create PrefetchProofs targets
2394        let mut targets1 = MultiProofTargets::default();
2395        targets1.insert(addr1, HashSet::default());
2396
2397        let mut targets2 = MultiProofTargets::default();
2398        targets2.insert(addr2, HashSet::default());
2399
2400        let mut targets3 = MultiProofTargets::default();
2401        targets3.insert(addr3, HashSet::default());
2402
2403        // Create StateUpdate 1
2404        let mut state_update1 = EvmState::default();
2405        state_update1.insert(
2406            state_addr1,
2407            Account {
2408                info: revm_state::AccountInfo {
2409                    balance: U256::from(100),
2410                    nonce: 1,
2411                    code_hash: Default::default(),
2412                    code: Default::default(),
2413                },
2414                transaction_id: Default::default(),
2415                storage: Default::default(),
2416                status: revm_state::AccountStatus::Touched,
2417            },
2418        );
2419
2420        // Create StateUpdate 2
2421        let mut state_update2 = EvmState::default();
2422        state_update2.insert(
2423            state_addr2,
2424            Account {
2425                info: revm_state::AccountInfo {
2426                    balance: U256::from(200),
2427                    nonce: 2,
2428                    code_hash: Default::default(),
2429                    code: Default::default(),
2430                },
2431                transaction_id: Default::default(),
2432                storage: Default::default(),
2433                status: revm_state::AccountStatus::Touched,
2434            },
2435        );
2436
2437        let source = StateChangeSource::Transaction(42);
2438
2439        // Queue: [PrefetchProofs1, PrefetchProofs2, StateUpdate1, StateUpdate2, PrefetchProofs3]
2440        let tx = task.state_root_message_sender();
2441        tx.send(MultiProofMessage::PrefetchProofs(targets1)).unwrap();
2442        tx.send(MultiProofMessage::PrefetchProofs(targets2)).unwrap();
2443        tx.send(MultiProofMessage::StateUpdate(source, state_update1)).unwrap();
2444        tx.send(MultiProofMessage::StateUpdate(source, state_update2)).unwrap();
2445        tx.send(MultiProofMessage::PrefetchProofs(targets3.clone())).unwrap();
2446
2447        // Step 1: Receive and batch PrefetchProofs (should get targets1 + targets2)
2448        let mut pending_msg: Option<MultiProofMessage> = None;
2449        if let Ok(MultiProofMessage::PrefetchProofs(targets)) = task.rx.recv() {
2450            let mut merged_targets = targets;
2451            let mut num_batched = 1;
2452
2453            loop {
2454                match task.rx.try_recv() {
2455                    Ok(MultiProofMessage::PrefetchProofs(next_targets)) => {
2456                        merged_targets.extend(next_targets);
2457                        num_batched += 1;
2458                    }
2459                    Ok(other_msg) => {
2460                        // Store locally to preserve ordering (the fix)
                        pending_msg = Some(other_msg);
                        break;
                    }
                    Err(_) => break,
                }
            }

            // Should have batched exactly 2 PrefetchProofs (not 3!)
            assert_eq!(num_batched, 2, "Should batch only until different message type");
            assert_eq!(merged_targets.len(), 2);
            assert!(merged_targets.contains_key(&addr1));
            assert!(merged_targets.contains_key(&addr2));
            assert!(!merged_targets.contains_key(&addr3), "addr3 should NOT be in first batch");
        } else {
            panic!("Expected PrefetchProofs message");
        }

        // Step 2: The pending message should be StateUpdate1 (preserved ordering)
        match pending_msg {
            Some(MultiProofMessage::StateUpdate(_src, update)) => {
                assert!(update.contains_key(&state_addr1), "Should be first StateUpdate");
            }
            _ => panic!("StateUpdate1 was lost or reordered! The ordering fix is broken."),
        }

        // Step 3: Next in channel should be StateUpdate2
        match task.rx.try_recv() {
            Ok(MultiProofMessage::StateUpdate(_src, update)) => {
                assert!(update.contains_key(&state_addr2), "Should be second StateUpdate");
            }
            _ => panic!("StateUpdate2 was lost!"),
        }

        // Step 4: Next in channel should be PrefetchProofs3
        match task.rx.try_recv() {
            Ok(MultiProofMessage::PrefetchProofs(targets)) => {
                assert_eq!(targets.len(), 1);
                assert!(targets.contains_key(&addr3));
            }
            _ => panic!("PrefetchProofs3 was lost!"),
        }
    }

    /// Verifies that a pending message is processed before the next loop iteration (ordering).
    #[test]
    fn test_pending_message_processed_before_next_iteration() {
        use alloy_evm::block::StateChangeSource;
        use revm_state::Account;

        let test_provider_factory = create_test_provider_factory();
        let mut task = create_test_state_root_task(test_provider_factory);

        // Queue: Prefetch1, StateUpdate, Prefetch2
        let prefetch_addr1 = B256::random();
        let prefetch_addr2 = B256::random();
        let mut prefetch1 = MultiProofTargets::default();
        prefetch1.insert(prefetch_addr1, HashSet::default());
        let mut prefetch2 = MultiProofTargets::default();
        prefetch2.insert(prefetch_addr2, HashSet::default());

        let state_addr = alloy_primitives::Address::random();
        let mut state_update = EvmState::default();
        state_update.insert(
            state_addr,
            Account {
                info: revm_state::AccountInfo {
                    balance: U256::from(42),
                    nonce: 1,
                    code_hash: Default::default(),
                    code: Default::default(),
                },
                transaction_id: Default::default(),
                storage: Default::default(),
                status: revm_state::AccountStatus::Touched,
            },
        );

        let source = StateChangeSource::Transaction(99);

        let tx = task.state_root_message_sender();
        tx.send(MultiProofMessage::PrefetchProofs(prefetch1)).unwrap();
        tx.send(MultiProofMessage::StateUpdate(source, state_update)).unwrap();
        tx.send(MultiProofMessage::PrefetchProofs(prefetch2.clone())).unwrap();

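        // Drive `process_multiproof_message` directly: the batching context and metrics below
        // stand in for the state that the task's run loop keeps across iterations.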
        let mut ctx = MultiproofBatchCtx::new(Instant::now());
        let mut batch_metrics = MultiproofBatchMetrics::default();

        // First message: Prefetch1 batches; StateUpdate becomes pending.
        let first = task.rx.recv().unwrap();
        assert!(matches!(first, MultiProofMessage::PrefetchProofs(_)));
        assert!(!task.process_multiproof_message(first, &mut ctx, &mut batch_metrics));
        let pending = ctx.pending_msg.take().expect("pending message captured");
        assert!(matches!(pending, MultiProofMessage::StateUpdate(_, _)));

        // Pending message should be handled before the next select loop.
        assert!(!task.process_multiproof_message(pending, &mut ctx, &mut batch_metrics));

        // Prefetch2 should now be in pending_msg (captured by StateUpdate's batching loop).
        match ctx.pending_msg.take() {
            Some(MultiProofMessage::PrefetchProofs(targets)) => {
                assert_eq!(targets.len(), 1);
                assert!(targets.contains_key(&prefetch_addr2));
            }
            other => panic!("Expected remaining PrefetchProofs2 in pending_msg, got {other:?}"),
        }
    }

    /// Verifies that pending messages from a previous batch drain get full batching treatment.
    #[test]
    fn test_pending_messages_get_full_batching_treatment() {
        // Queue: [Prefetch1, State1, State2, State3, Prefetch2]
        //
        // Expected behavior:
        // 1. recv() → Prefetch1
        // 2. try_recv() → State1 is different type → pending = State1, break
        // 3. Process Prefetch1
        // 4. Next iteration: pending = State1 → process with batching
        // 5. try_recv() → State2 same type → merge
        // 6. try_recv() → State3 same type → merge
        // 7. try_recv() → Prefetch2 different type → pending = Prefetch2, break
        // 8. Process merged State (1+2+3)
        //
        // Without the state-machine fix, State1 would be processed alone (no batching).
        use alloy_evm::block::StateChangeSource;
        use revm_state::Account;

        let test_provider_factory = create_test_provider_factory();
        let task = create_test_state_root_task(test_provider_factory);

        let prefetch_addr1 = B256::random();
        let prefetch_addr2 = B256::random();
        let state_addr1 = alloy_primitives::Address::random();
        let state_addr2 = alloy_primitives::Address::random();
        let state_addr3 = alloy_primitives::Address::random();

        // Create Prefetch targets
        let mut prefetch1 = MultiProofTargets::default();
        prefetch1.insert(prefetch_addr1, HashSet::default());

        let mut prefetch2 = MultiProofTargets::default();
        prefetch2.insert(prefetch_addr2, HashSet::default());

        // Create StateUpdates
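        // Each update touches exactly one account, so the merged update can be told apart by
        // address below.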
        let create_state_update = |addr: alloy_primitives::Address, balance: u64| {
            let mut state = EvmState::default();
            state.insert(
                addr,
                Account {
                    info: revm_state::AccountInfo {
                        balance: U256::from(balance),
                        nonce: 1,
                        code_hash: Default::default(),
                        code: Default::default(),
                    },
                    transaction_id: Default::default(),
                    storage: Default::default(),
                    status: revm_state::AccountStatus::Touched,
                },
            );
            state
        };

        let source = StateChangeSource::Transaction(42);

        // Queue: [Prefetch1, State1, State2, State3, Prefetch2]
        let tx = task.state_root_message_sender();
        tx.send(MultiProofMessage::PrefetchProofs(prefetch1.clone())).unwrap();
        tx.send(MultiProofMessage::StateUpdate(source, create_state_update(state_addr1, 100)))
            .unwrap();
        tx.send(MultiProofMessage::StateUpdate(source, create_state_update(state_addr2, 200)))
            .unwrap();
        tx.send(MultiProofMessage::StateUpdate(source, create_state_update(state_addr3, 300)))
            .unwrap();
        tx.send(MultiProofMessage::PrefetchProofs(prefetch2.clone())).unwrap();

        // Simulate the state-machine loop behavior
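        // One iteration: recv() one message, drain same-typed messages via try_recv(), and stash
        // the first differently-typed message as pending.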
        let mut pending_msg: Option<MultiProofMessage> = None;

        // First iteration: recv() gets Prefetch1, drains until State1
        if let Ok(MultiProofMessage::PrefetchProofs(targets)) = task.rx.recv() {
            let mut merged_targets = targets;
            loop {
                match task.rx.try_recv() {
                    Ok(MultiProofMessage::PrefetchProofs(next_targets)) => {
                        merged_targets.extend(next_targets);
                    }
                    Ok(other_msg) => {
                        pending_msg = Some(other_msg);
                        break;
                    }
                    Err(_) => break,
                }
            }
            // Should have only Prefetch1 (State1 is different type)
            assert_eq!(merged_targets.len(), 1);
            assert!(merged_targets.contains_key(&prefetch_addr1));
        } else {
            panic!("Expected PrefetchProofs");
        }

        // Pending should be State1
        assert!(matches!(pending_msg, Some(MultiProofMessage::StateUpdate(_, _))));

        // Second iteration: process the pending State1 WITH BATCHING.
        // This is the key check: the pending message should also drain State2 and State3.
        if let Some(MultiProofMessage::StateUpdate(_src, first_update)) = pending_msg.take() {
            let mut merged_update = first_update;
            let mut num_batched = 1;

            loop {
                match task.rx.try_recv() {
                    Ok(MultiProofMessage::StateUpdate(_src, next_update)) => {
                        merged_update.extend(next_update);
                        num_batched += 1;
                    }
                    Ok(other_msg) => {
                        pending_msg = Some(other_msg);
                        break;
                    }
                    Err(_) => break,
                }
            }

            // THE KEY ASSERTION: pending State1 should have batched with State2 and State3
            assert_eq!(
                num_batched, 3,
                "Pending message should get full batching treatment and merge all 3 StateUpdates"
            );
            assert_eq!(merged_update.len(), 3, "Should have all 3 addresses in merged update");
            assert!(merged_update.contains_key(&state_addr1));
            assert!(merged_update.contains_key(&state_addr2));
            assert!(merged_update.contains_key(&state_addr3));
        } else {
            panic!("Expected pending StateUpdate");
        }

        // Pending should now be Prefetch2
        match pending_msg {
            Some(MultiProofMessage::PrefetchProofs(targets)) => {
                assert_eq!(targets.len(), 1);
                assert!(targets.contains_key(&prefetch_addr2));
            }
            _ => panic!("Prefetch2 was lost!"),
        }
    }
}