1use alloy_eip7928::BAL_RETENTION_PERIOD_SLOTS;
2use alloy_eips::NumHash;
3use alloy_primitives::{BlockHash, BlockNumber, Bytes};
4use parking_lot::RwLock;
5use reth_prune_types::PruneMode;
6use reth_storage_api::{
7 BalNotification, BalNotificationStream, BalStore, GetBlockAccessListLimit, SealedBal,
8};
9use reth_storage_errors::provider::ProviderResult;
10use reth_tokio_util::EventSender;
11use std::{
12 collections::{BTreeMap, HashMap},
13 sync::Arc,
14};
15
/// In-memory [`BalStore`] that keeps block access lists (BALs) keyed by block hash.
///
/// Entries are pruned according to [`BalConfig`] relative to the highest block number inserted so
/// far. Cloning the store is cheap: clones share the same underlying storage and the same
/// notification sender.
#[derive(Debug, Clone)]
pub struct InMemoryBalStore {
    /// Retention configuration applied on insert and on explicit prune.
    config: BalConfig,
    /// Shared, lock-protected storage; cloning the store clones the `Arc`, not the data.
    inner: Arc<RwLock<InMemoryBalStoreInner>>,
    /// Sender used to notify `bal_stream` listeners about inserted BALs.
    notifications: EventSender<BalNotification>,
}
23
24impl InMemoryBalStore {
25 pub fn new(config: BalConfig) -> Self {
27 let notifications = EventSender::new(DEFAULT_BAL_NOTIFICATION_CHANNEL_SIZE);
28 Self {
29 config,
30 inner: Arc::new(RwLock::new(InMemoryBalStoreInner::default())),
31 notifications,
32 }
33 }
34}
35
/// Capacity passed to the [`EventSender`] backing BAL notifications. A listener that falls more
/// than this many notifications behind misses the oldest ones (see the
/// `bal_stream_skips_lagged_notifications` test).
const DEFAULT_BAL_NOTIFICATION_CHANNEL_SIZE: usize = 256;
39
40impl Default for InMemoryBalStore {
41 fn default() -> Self {
42 Self::new(BalConfig::default())
43 }
44}
45
/// Configuration for [`InMemoryBalStore`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct BalConfig {
    /// How long BALs are kept in memory; `None` disables pruning entirely.
    in_memory_retention: Option<PruneMode>,
}
52
53impl BalConfig {
54 pub const DEFAULT_IN_MEMORY_RETENTION_DISTANCE: u64 = BAL_RETENTION_PERIOD_SLOTS;
56
57 pub const fn unbounded() -> Self {
59 Self { in_memory_retention: None }
60 }
61
62 pub const fn with_in_memory_retention_distance(blocks: u64) -> Self {
64 Self::with_in_memory_retention(PruneMode::Distance(blocks))
65 }
66
67 pub const fn with_in_memory_retention(in_memory_retention: PruneMode) -> Self {
69 Self { in_memory_retention: Some(in_memory_retention) }
70 }
71}
72
73impl Default for BalConfig {
74 fn default() -> Self {
75 Self::with_in_memory_retention_distance(Self::DEFAULT_IN_MEMORY_RETENTION_DISTANCE)
76 }
77}
78
/// Lock-protected state of [`InMemoryBalStore`].
#[derive(Debug, Default)]
struct InMemoryBalStoreInner {
    /// BAL entries keyed by block hash.
    entries: HashMap<BlockHash, BalEntry>,
    /// Index from block number to every stored hash at that height, walked in ascending order
    /// when pruning. Several hashes can share one number (presumably blocks on competing forks).
    hashes_by_number: BTreeMap<BlockNumber, Vec<BlockHash>>,
    /// Highest block number ever inserted; used as the pruning tip on insert. Pruning never
    /// lowers it.
    highest_block_number: Option<BlockNumber>,
}
85
86impl InMemoryBalStoreInner {
87 fn insert(&mut self, block_hash: BlockHash, block_number: BlockNumber, bal: Bytes) {
89 let empty_block_number =
90 self.entries.insert(block_hash, BalEntry { block_number, bal }).and_then(|entry| {
91 let hashes = self.hashes_by_number.get_mut(&entry.block_number)?;
92 hashes.retain(|hash| *hash != block_hash);
93 hashes.is_empty().then_some(entry.block_number)
94 });
95
96 if let Some(block_number) = empty_block_number {
97 self.hashes_by_number.remove(&block_number);
98 }
99
100 self.hashes_by_number.entry(block_number).or_default().push(block_hash);
101 self.highest_block_number = Some(
102 self.highest_block_number.map_or(block_number, |highest| highest.max(block_number)),
103 );
104 }
105
106 fn prune(&mut self, prune_mode: Option<PruneMode>, tip: BlockNumber) -> usize {
108 let Some(prune_mode) = prune_mode else { return 0 };
109
110 let mut pruned = 0;
111 while let Some((&block_number, _)) = self.hashes_by_number.first_key_value() {
112 if !prune_mode.should_prune(block_number, tip) {
113 break
114 }
115
116 let Some((_, hashes)) = self.hashes_by_number.pop_first() else { break };
117 for hash in hashes {
118 pruned += usize::from(self.entries.remove(&hash).is_some());
119 }
120 }
121 pruned
122 }
123}
124
/// A stored BAL together with the block number it was inserted under.
#[derive(Debug)]
struct BalEntry {
    /// Block number; kept so a re-insert of the same hash can unlink its old index slot.
    block_number: BlockNumber,
    /// Raw BAL bytes returned by lookups.
    bal: Bytes,
}
130
131impl BalStore for InMemoryBalStore {
132 fn insert(&self, num_hash: NumHash, bal: SealedBal) -> ProviderResult<()> {
133 let mut inner = self.inner.write();
134 inner.insert(num_hash.hash, num_hash.number, bal.clone_inner());
135 if let Some(highest_block_number) = inner.highest_block_number {
136 inner.prune(self.config.in_memory_retention, highest_block_number);
138 }
139 self.notifications.notify(BalNotification::new(num_hash, bal));
140 Ok(())
141 }
142
143 fn insert_many(&self, entries: Vec<(NumHash, SealedBal)>) -> ProviderResult<()> {
144 if entries.is_empty() {
145 return Ok(())
146 }
147
148 let mut inner = self.inner.write();
149 inner.entries.reserve(entries.len());
150 for (num_hash, bal) in &entries {
151 inner.insert(num_hash.hash, num_hash.number, bal.clone_inner());
152 }
153 if let Some(highest_block_number) = inner.highest_block_number {
154 inner.prune(self.config.in_memory_retention, highest_block_number);
155 }
156 drop(inner);
157
158 for (num_hash, bal) in entries {
159 self.notifications.notify(BalNotification::new(num_hash, bal));
160 }
161 Ok(())
162 }
163
164 fn flush(&self) -> ProviderResult<()> {
165 Ok(())
166 }
167
168 fn prune(&self, tip: BlockNumber) -> ProviderResult<usize> {
169 Ok(self.inner.write().prune(self.config.in_memory_retention, tip))
170 }
171
172 fn get_by_hashes(&self, block_hashes: &[BlockHash]) -> ProviderResult<Vec<Option<Bytes>>> {
173 let inner = self.inner.read();
174 let mut result = Vec::with_capacity(block_hashes.len());
175
176 for hash in block_hashes {
177 result.push(inner.entries.get(hash).map(|entry| entry.bal.clone()));
178 }
179
180 Ok(result)
181 }
182
183 fn append_by_hashes_with_limit(
184 &self,
185 block_hashes: &[BlockHash],
186 limit: GetBlockAccessListLimit,
187 out: &mut Vec<Option<Bytes>>,
188 ) -> ProviderResult<()> {
189 let inner = self.inner.read();
190 let mut size = 0;
191
192 for hash in block_hashes {
193 let bal = inner.entries.get(hash).map(|entry| entry.bal.clone());
194 size += bal.as_ref().map_or(1, |bytes| bytes.len());
195 out.push(bal);
196
197 if limit.exceeds(size) {
198 break
199 }
200 }
201
202 Ok(())
203 }
204
205 fn get_by_range(&self, _start: BlockNumber, _count: u64) -> ProviderResult<Vec<Bytes>> {
206 Ok(Vec::new())
207 }
208
209 fn bal_stream(&self) -> BalNotificationStream {
210 self.notifications.new_listener()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{keccak256, Sealed, B256};
    use tokio_stream::StreamExt;

    fn sealed_bal(bal: Bytes) -> SealedBal {
        Sealed::new_unchecked(bal.clone(), keccak256(&bal))
    }

    #[test]
    fn insert_and_lookup_by_hash() {
        let store = InMemoryBalStore::default();
        let hash = B256::random();
        let missing = B256::random();
        let bal = Bytes::from_static(b"bal");

        store.insert(NumHash::new(1, hash), sealed_bal(bal.clone())).unwrap();

        assert_eq!(store.get_by_hashes(&[hash, missing]).unwrap(), vec![Some(bal), None]);
    }

    #[test]
    fn insert_many_and_lookup_by_hash() {
        let store = InMemoryBalStore::default();
        let hash0 = B256::random();
        let hash1 = B256::random();
        let bal0 = sealed_bal(Bytes::from_static(b"bal0"));
        let bal1 = sealed_bal(Bytes::from_static(b"bal1"));

        store
            .insert_many(vec![
                (NumHash::new(1, hash0), bal0.clone()),
                (NumHash::new(2, hash1), bal1),
            ])
            .unwrap();

        assert_eq!(
            store.get_by_hashes(&[hash0, hash1]).unwrap(),
            vec![Some(bal0.clone_inner()), Some(Bytes::from_static(b"bal1"))]
        );
    }

    #[test]
    fn range_lookup_is_empty() {
        let store = InMemoryBalStore::default();

        assert!(store.get_by_range(1, 10).unwrap().is_empty());
    }

    #[test]
    fn flush_is_noop() {
        let store = InMemoryBalStore::default();

        store.flush().unwrap();
    }

    #[test]
    fn limited_lookup_returns_prefix() {
        let store = InMemoryBalStore::default();
        let hash0 = B256::random();
        let hash1 = B256::random();
        let hash2 = B256::random();
        let bal0 = Bytes::from_static(&[0xc1, 0x01]);
        let bal1 = Bytes::from_static(&[0xc1, 0x02]);
        let bal2 = Bytes::from_static(&[0xc1, 0x03]);

        store.insert(NumHash::new(1, hash0), sealed_bal(bal0.clone())).unwrap();
        store.insert(NumHash::new(2, hash1), sealed_bal(bal1.clone())).unwrap();
        store.insert(NumHash::new(3, hash2), sealed_bal(bal2)).unwrap();

        let limited = store
            .get_by_hashes_with_limit(
                &[hash0, hash1, hash2],
                GetBlockAccessListLimit::ResponseSizeSoftLimit(2),
            )
            .unwrap();

        assert_eq!(limited, vec![Some(bal0), Some(bal1)]);
    }

    #[test]
    fn default_retention_prunes_old_bals() {
        let store = InMemoryBalStore::default();
        let old_hash = B256::random();
        let retained_hash = B256::random();
        let tip_hash = B256::random();
        let old_bal = Bytes::from_static(b"old");
        let retained_bal = Bytes::from_static(b"retained");
        let tip_bal = Bytes::from_static(b"tip");

        store.insert(NumHash::new(1, old_hash), sealed_bal(old_bal)).unwrap();
        store
            .insert(
                NumHash::new(BAL_RETENTION_PERIOD_SLOTS, retained_hash),
                sealed_bal(retained_bal.clone()),
            )
            .unwrap();
        store
            .insert(
                NumHash::new(BAL_RETENTION_PERIOD_SLOTS + 2, tip_hash),
                sealed_bal(tip_bal.clone()),
            )
            .unwrap();

        assert_eq!(
            store.get_by_hashes(&[old_hash, retained_hash, tip_hash]).unwrap(),
            vec![None, Some(retained_bal), Some(tip_bal)]
        );
    }

    #[test]
    fn prune_uses_chain_tip() {
        let store =
            InMemoryBalStore::new(BalConfig::with_in_memory_retention(PruneMode::Distance(2)));
        let old_hash = B256::random();
        let retained_hash = B256::random();
        let old_bal = Bytes::from_static(b"old");
        let retained_bal = Bytes::from_static(b"retained");

        store.insert(NumHash::new(7, old_hash), sealed_bal(old_bal)).unwrap();
        store.insert(NumHash::new(8, retained_hash), sealed_bal(retained_bal.clone())).unwrap();

        assert_eq!(store.prune(10).unwrap(), 1);
        assert_eq!(
            store.get_by_hashes(&[old_hash, retained_hash]).unwrap(),
            vec![None, Some(retained_bal)]
        );
    }

    #[test]
    fn insert_prunes_from_highest_inserted_block() {
        let store =
            InMemoryBalStore::new(BalConfig::with_in_memory_retention(PruneMode::Distance(2)));
        let old_hash = B256::random();
        let high_hash = B256::random();
        let late_hash = B256::random();
        let high_bal = Bytes::from_static(b"high");
        let late_bal = Bytes::from_static(b"late");

        store.insert(NumHash::new(7, old_hash), sealed_bal(Bytes::from_static(b"old"))).unwrap();
        store.insert(NumHash::new(10, high_hash), sealed_bal(high_bal.clone())).unwrap();
        store.insert(NumHash::new(8, late_hash), sealed_bal(late_bal.clone())).unwrap();

        assert_eq!(
            store.get_by_hashes(&[old_hash, high_hash, late_hash]).unwrap(),
            vec![None, Some(high_bal), Some(late_bal)]
        );
    }

    #[test]
    fn unbounded_retention_keeps_old_bals() {
        let store = InMemoryBalStore::new(BalConfig::unbounded());
        let old_hash = B256::random();
        let tip_hash = B256::random();
        let old_bal = Bytes::from_static(b"old");
        let tip_bal = Bytes::from_static(b"tip");

        store.insert(NumHash::new(1, old_hash), sealed_bal(old_bal.clone())).unwrap();
        store
            .insert(
                NumHash::new(BAL_RETENTION_PERIOD_SLOTS + 1, tip_hash),
                sealed_bal(tip_bal.clone()),
            )
            .unwrap();

        assert_eq!(
            store.get_by_hashes(&[old_hash, tip_hash]).unwrap(),
            vec![Some(old_bal), Some(tip_bal)]
        );
        assert_eq!(store.prune(BAL_RETENTION_PERIOD_SLOTS + 2).unwrap(), 0);
    }

    #[test]
    fn in_memory_retention_distance_prunes_old_bals() {
        let store = InMemoryBalStore::new(BalConfig::with_in_memory_retention_distance(2));
        let old_hash = B256::random();
        let retained_hash = B256::random();
        let tip_hash = B256::random();
        let old_bal = Bytes::from_static(b"old");
        let retained_bal = Bytes::from_static(b"retained");
        let tip_bal = Bytes::from_static(b"tip");

        store.insert(NumHash::new(1, old_hash), sealed_bal(old_bal)).unwrap();
        store.insert(NumHash::new(2, retained_hash), sealed_bal(retained_bal.clone())).unwrap();
        store.insert(NumHash::new(4, tip_hash), sealed_bal(tip_bal.clone())).unwrap();

        assert_eq!(
            store.get_by_hashes(&[old_hash, retained_hash, tip_hash]).unwrap(),
            vec![None, Some(retained_bal), Some(tip_bal)]
        );
    }

    #[test]
    fn reinserting_hash_updates_number_index() {
        // `Distance(2)` keeps block 1 while it is the tip. The second insert therefore replaces a
        // live entry, rather than re-adding one that retention already pruned.
        let store =
            InMemoryBalStore::new(BalConfig::with_in_memory_retention(PruneMode::Distance(2)));
        let hash = B256::random();
        let bal = Bytes::from_static(b"bal");

        store.insert(NumHash::new(1, hash), sealed_bal(Bytes::from_static(b"old"))).unwrap();
        store.insert(NumHash::new(5, hash), sealed_bal(bal.clone())).unwrap();

        // A stale index slot at block 1 would be pruned at tip 5 and evict the re-inserted entry.
        assert_eq!(store.get_by_hashes(&[hash]).unwrap(), vec![Some(bal)]);

        let inner = store.inner.read();
        assert_eq!(inner.hashes_by_number.len(), 1);
        assert!(inner.hashes_by_number.contains_key(&5));
        assert_eq!(inner.entries.len(), 1);
    }

    #[test]
    fn reinserting_hash_at_same_number_keeps_single_index_entry() {
        let store = InMemoryBalStore::new(BalConfig::unbounded());
        let hash = B256::random();

        store.insert(NumHash::new(3, hash), sealed_bal(Bytes::from_static(b"first"))).unwrap();
        store.insert(NumHash::new(3, hash), sealed_bal(Bytes::from_static(b"second"))).unwrap();

        assert_eq!(
            store.get_by_hashes(&[hash]).unwrap(),
            vec![Some(Bytes::from_static(b"second"))]
        );
        assert_eq!(store.inner.read().hashes_by_number.get(&3).map(Vec::len), Some(1));
    }

    #[tokio::test]
    async fn insert_notifies_subscribers() {
        let store = InMemoryBalStore::default();
        let hash = B256::random();
        let block_number = 7;
        let bal = Bytes::from_static(b"bal");
        let mut stream = store.bal_stream();

        let sealed_bal = sealed_bal(bal);

        store.insert(NumHash::new(block_number, hash), sealed_bal.clone()).unwrap();

        assert_eq!(
            stream.next().await.unwrap(),
            BalNotification::new(NumHash::new(block_number, hash), sealed_bal)
        );
    }

    #[tokio::test]
    async fn insert_many_notifies_subscribers() {
        let store = InMemoryBalStore::default();
        let mut stream = store.bal_stream();
        let hash0 = B256::random();
        let hash1 = B256::random();
        let bal0 = sealed_bal(Bytes::from_static(b"bal0"));
        let bal1 = sealed_bal(Bytes::from_static(b"bal1"));

        store
            .insert_many(vec![
                (NumHash::new(1, hash0), bal0.clone()),
                (NumHash::new(2, hash1), bal1.clone()),
            ])
            .unwrap();

        assert_eq!(
            stream.next().await.unwrap(),
            BalNotification::new(NumHash::new(1, hash0), bal0)
        );
        assert_eq!(
            stream.next().await.unwrap(),
            BalNotification::new(NumHash::new(2, hash1), bal1)
        );
    }

    #[test]
    fn insert_without_subscribers_still_succeeds() {
        let store = InMemoryBalStore::default();

        assert!(store
            .insert(NumHash::new(1, B256::random()), sealed_bal(Bytes::from_static(b"bal")))
            .is_ok());
    }

    #[tokio::test]
    async fn bal_stream_skips_lagged_notifications() {
        let store = InMemoryBalStore::new(BalConfig::unbounded());
        let mut stream = store.bal_stream();

        for number in 0..=DEFAULT_BAL_NOTIFICATION_CHANNEL_SIZE as u64 {
            store
                .insert(
                    NumHash::new(number, B256::random()),
                    sealed_bal(Bytes::from(vec![number as u8])),
                )
                .unwrap();
        }

        let first = stream.next().await.unwrap();
        let second = stream.next().await.unwrap();

        assert_eq!(first.num_hash.number, 1);
        assert_eq!(second.num_hash.number, 2);
    }

    #[tokio::test]
    async fn cloned_store_shares_notification_channel() {
        let store = InMemoryBalStore::default();
        let clone = store.clone();
        let hash = B256::random();
        let block_number = 9;
        let bal = Bytes::from_static(b"bal");
        let mut stream = clone.bal_stream();

        let sealed_bal = sealed_bal(bal);

        store.insert(NumHash::new(block_number, hash), sealed_bal.clone()).unwrap();

        assert_eq!(
            stream.next().await.unwrap(),
            BalNotification::new(NumHash::new(block_number, hash), sealed_bal)
        );
    }
}