1use std::{collections::HashMap, future::Future, sync::Arc};
2
3use alloy_eips::BlockId;
4use alloy_primitives::{Address, U256};
5use async_trait::async_trait;
6use futures::StreamExt;
7use jsonrpsee::{core::RpcResult, PendingSubscriptionSink, SubscriptionMessage, SubscriptionSink};
8use jsonrpsee_types::ErrorObject;
9use reth_chain_state::{CanonStateNotificationStream, CanonStateSubscriptions};
10use reth_errors::RethResult;
11use reth_primitives_traits::NodePrimitives;
12use reth_rpc_api::RethApiServer;
13use reth_rpc_eth_types::{EthApiError, EthResult};
14use reth_rpc_server_types::result::internal_rpc_err;
15use reth_storage_api::{BlockReaderIdExt, ChangeSetReader, StateProviderFactory};
16use reth_tasks::TaskSpawner;
17use tokio::sync::oneshot;
18
19pub struct RethApi<Provider> {
23 inner: Arc<RethApiInner<Provider>>,
24}
25
26impl<Provider> RethApi<Provider> {
29 pub fn provider(&self) -> &Provider {
31 &self.inner.provider
32 }
33
34 pub fn new(provider: Provider, task_spawner: Box<dyn TaskSpawner>) -> Self {
36 let inner = Arc::new(RethApiInner { provider, task_spawner });
37 Self { inner }
38 }
39}
40
41impl<Provider> RethApi<Provider>
42where
43 Provider: BlockReaderIdExt + ChangeSetReader + StateProviderFactory + 'static,
44{
45 async fn on_blocking_task<C, F, R>(&self, c: C) -> EthResult<R>
47 where
48 C: FnOnce(Self) -> F,
49 F: Future<Output = EthResult<R>> + Send + 'static,
50 R: Send + 'static,
51 {
52 let (tx, rx) = oneshot::channel();
53 let this = self.clone();
54 let f = c(this);
55 self.inner.task_spawner.spawn_blocking(Box::pin(async move {
56 let res = f.await;
57 let _ = tx.send(res);
58 }));
59 rx.await.map_err(|_| EthApiError::InternalEthError)?
60 }
61
62 pub async fn balance_changes_in_block(
64 &self,
65 block_id: BlockId,
66 ) -> EthResult<HashMap<Address, U256>> {
67 self.on_blocking_task(|this| async move { this.try_balance_changes_in_block(block_id) })
68 .await
69 }
70
71 fn try_balance_changes_in_block(&self, block_id: BlockId) -> EthResult<HashMap<Address, U256>> {
72 let Some(block_number) = self.provider().block_number_for_id(block_id)? else {
73 return Err(EthApiError::HeaderNotFound(block_id))
74 };
75
76 let state = self.provider().state_by_block_id(block_id)?;
77 let accounts_before = self.provider().account_block_changeset(block_number)?;
78 let hash_map = accounts_before.iter().try_fold(
79 HashMap::default(),
80 |mut hash_map, account_before| -> RethResult<_> {
81 let current_balance = state.account_balance(&account_before.address)?;
82 let prev_balance = account_before.info.map(|info| info.balance);
83 if current_balance != prev_balance {
84 hash_map.insert(account_before.address, current_balance.unwrap_or_default());
85 }
86 Ok(hash_map)
87 },
88 )?;
89 Ok(hash_map)
90 }
91}
92
93#[async_trait]
94impl<Provider> RethApiServer for RethApi<Provider>
95where
96 Provider: BlockReaderIdExt
97 + ChangeSetReader
98 + StateProviderFactory
99 + CanonStateSubscriptions
100 + 'static,
101{
102 async fn reth_get_balance_changes_in_block(
104 &self,
105 block_id: BlockId,
106 ) -> RpcResult<HashMap<Address, U256>> {
107 Ok(Self::balance_changes_in_block(self, block_id).await?)
108 }
109
110 async fn reth_subscribe_chain_notifications(
112 &self,
113 pending: PendingSubscriptionSink,
114 ) -> jsonrpsee::core::SubscriptionResult {
115 let sink = pending.accept().await?;
116 let stream = self.provider().canonical_state_stream();
117 self.inner.task_spawner.spawn(Box::pin(async move {
118 let _ = pipe_from_stream(sink, stream).await;
119 }));
120
121 Ok(())
122 }
123}
124
125async fn pipe_from_stream<N: NodePrimitives>(
127 sink: SubscriptionSink,
128 mut stream: CanonStateNotificationStream<N>,
129) -> Result<(), ErrorObject<'static>> {
130 loop {
131 tokio::select! {
132 _ = sink.closed() => {
133 break Ok(())
135 }
136 maybe_item = stream.next() => {
137 let item = match maybe_item {
138 Some(item) => item,
139 None => {
140 break Ok(())
142 },
143 };
144 let msg = SubscriptionMessage::new(sink.method_name(), sink.subscription_id(), &item)
145 .map_err(|e| internal_rpc_err(e.to_string()))?;
146
147 if sink.send(msg).await.is_err() {
148 break Ok(());
149 }
150 }
151 }
152 }
153}
154
155impl<Provider> std::fmt::Debug for RethApi<Provider> {
156 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("RethApi").finish_non_exhaustive()
158 }
159}
160
161impl<Provider> Clone for RethApi<Provider> {
162 fn clone(&self) -> Self {
163 Self { inner: Arc::clone(&self.inner) }
164 }
165}
166
167struct RethApiInner<Provider> {
168 provider: Provider,
170 task_spawner: Box<dyn TaskSpawner>,
172}