reth_bench/bench/
persistence_waiter.rs1use alloy_eips::BlockNumHash;
11use alloy_network::Ethereum;
12use alloy_provider::{Provider, RootProvider};
13use alloy_pubsub::SubscriptionStream;
14use alloy_rpc_client::RpcClient;
15use alloy_transport_ws::WsConnect;
16use eyre::Context;
17use futures::StreamExt;
18use std::time::Duration;
19use tracing::{debug, info};
20
/// Port placed into the WebSocket RPC URL that is derived from the engine RPC URL
/// (see `engine_url_to_ws_url`).
const DEFAULT_WS_RPC_PORT: u16 = 8_546;
23use url::Url;
24
25pub(crate) fn derive_ws_rpc_url(
36 ws_rpc_url: Option<&str>,
37 engine_rpc_url: &str,
38) -> eyre::Result<Url> {
39 if let Some(ws_url) = ws_rpc_url {
40 let parsed: Url = ws_url
41 .parse()
42 .wrap_err_with(|| format!("Failed to parse WebSocket RPC URL: {ws_url}"))?;
43 info!(target: "reth-bench", ws_url = %parsed, "Using provided WebSocket RPC URL");
44 Ok(parsed)
45 } else {
46 let derived = engine_url_to_ws_url(engine_rpc_url)?;
47 debug!(
48 target: "reth-bench",
49 engine_url = %engine_rpc_url,
50 %derived,
51 "Derived WebSocket RPC URL from engine RPC URL"
52 );
53 Ok(derived)
54 }
55}
56
57fn engine_url_to_ws_url(engine_url: &str) -> eyre::Result<Url> {
68 let url: Url = engine_url
69 .parse()
70 .wrap_err_with(|| format!("Failed to parse engine RPC URL: {engine_url}"))?;
71
72 let mut ws_url = url.clone();
73
74 match ws_url.scheme() {
75 "http" => ws_url
76 .set_scheme("ws")
77 .map_err(|_| eyre::eyre!("Failed to set WS scheme for URL: {url}"))?,
78 "https" => ws_url
79 .set_scheme("wss")
80 .map_err(|_| eyre::eyre!("Failed to set WSS scheme for URL: {url}"))?,
81 "ws" | "wss" => {}
82 scheme => {
83 return Err(eyre::eyre!(
84 "Unsupported URL scheme '{scheme}' for URL: {url}. Expected http, https, ws, or wss."
85 ))
86 }
87 }
88
89 ws_url
90 .set_port(Some(DEFAULT_WS_RPC_PORT))
91 .map_err(|_| eyre::eyre!("Failed to set port for URL: {url}"))?;
92
93 Ok(ws_url)
94}
95
96async fn wait_for_persistence(
102 stream: &mut SubscriptionStream<BlockNumHash>,
103 target: u64,
104 last_persisted: &mut u64,
105 timeout: Duration,
106) -> eyre::Result<()> {
107 tokio::time::timeout(timeout, async {
108 while *last_persisted < target {
109 match stream.next().await {
110 Some(persisted) => {
111 *last_persisted = persisted.number;
112 debug!(
113 target: "reth-bench",
114 persisted_block = ?last_persisted,
115 "Received persistence notification"
116 );
117 }
118 None => {
119 return Err(eyre::eyre!("Persistence subscription closed unexpectedly"));
120 }
121 }
122 }
123 Ok(())
124 })
125 .await
126 .map_err(|_| {
127 eyre::eyre!(
128 "Persistence timeout: target block {} not persisted within {:?}. Last persisted: {}",
129 target,
130 timeout,
131 last_persisted
132 )
133 })?
134}
135
136pub(crate) struct PersistenceSubscription {
139 _provider: RootProvider<Ethereum>,
140 stream: SubscriptionStream<BlockNumHash>,
141}
142
143impl PersistenceSubscription {
144 const fn new(
145 provider: RootProvider<Ethereum>,
146 stream: SubscriptionStream<BlockNumHash>,
147 ) -> Self {
148 Self { _provider: provider, stream }
149 }
150
151 const fn stream_mut(&mut self) -> &mut SubscriptionStream<BlockNumHash> {
152 &mut self.stream
153 }
154}
155
156pub(crate) async fn setup_persistence_subscription(
162 ws_url: Url,
163 persistence_timeout: Duration,
164) -> eyre::Result<PersistenceSubscription> {
165 info!(target: "reth-bench", "Connecting to WebSocket at {} for persistence subscription", ws_url);
166
167 let ws_connect =
168 WsConnect::new(ws_url.to_string()).with_keepalive_interval(persistence_timeout);
169 let client = RpcClient::connect_pubsub(ws_connect)
170 .await
171 .wrap_err("Failed to connect to WebSocket RPC endpoint")?;
172 let provider: RootProvider<Ethereum> = RootProvider::new(client);
173
174 let subscription = provider
175 .subscribe_to::<BlockNumHash>("reth_subscribePersistedBlock")
176 .await
177 .wrap_err("Failed to subscribe to persistence notifications")?;
178
179 info!(target: "reth-bench", "Subscribed to persistence notifications");
180
181 Ok(PersistenceSubscription::new(provider, subscription.into_stream()))
182}
183
184pub(crate) struct PersistenceWaiter {
192 wait_time: Option<Duration>,
193 subscription: Option<PersistenceSubscription>,
194 blocks_sent: u64,
195 last_persisted: u64,
196 threshold: u64,
197 timeout: Duration,
198}
199
200impl PersistenceWaiter {
201 pub(crate) const fn with_duration(wait_time: Duration) -> Self {
202 Self {
203 wait_time: Some(wait_time),
204 subscription: None,
205 blocks_sent: 0,
206 last_persisted: 0,
207 threshold: 0,
208 timeout: Duration::ZERO,
209 }
210 }
211
212 pub(crate) const fn with_subscription(
213 subscription: PersistenceSubscription,
214 threshold: u64,
215 timeout: Duration,
216 ) -> Self {
217 Self {
218 wait_time: None,
219 subscription: Some(subscription),
220 blocks_sent: 0,
221 last_persisted: 0,
222 threshold,
223 timeout,
224 }
225 }
226
227 pub(crate) const fn with_duration_and_subscription(
232 wait_time: Duration,
233 subscription: PersistenceSubscription,
234 threshold: u64,
235 timeout: Duration,
236 ) -> Self {
237 Self {
238 wait_time: Some(wait_time),
239 subscription: Some(subscription),
240 blocks_sent: 0,
241 last_persisted: 0,
242 threshold,
243 timeout,
244 }
245 }
246
247 #[allow(clippy::manual_is_multiple_of)]
253 pub(crate) async fn on_block(&mut self, block_number: u64) -> eyre::Result<()> {
254 if let Some(wait_time) = self.wait_time {
256 tokio::time::sleep(wait_time).await;
257 }
258
259 let Some(ref mut subscription) = self.subscription else {
261 return Ok(());
262 };
263
264 self.blocks_sent += 1;
265
266 if self.blocks_sent % (self.threshold + 1) == 0 {
267 debug!(
268 target: "reth-bench",
269 target_block = ?block_number,
270 last_persisted = self.last_persisted,
271 blocks_sent = self.blocks_sent,
272 "Waiting for persistence"
273 );
274
275 wait_for_persistence(
276 subscription.stream_mut(),
277 block_number,
278 &mut self.last_persisted,
279 self.timeout,
280 )
281 .await?;
282
283 debug!(
284 target: "reth-bench",
285 persisted = self.last_persisted,
286 "Persistence caught up"
287 );
288 }
289
290 Ok(())
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Instant;

    #[test]
    fn test_engine_url_to_ws_url() {
        // http -> ws with the WS port substituted.
        let result = engine_url_to_ws_url("http://localhost:8551").unwrap();
        assert_eq!(result.as_str(), "ws://localhost:8546/");

        // https -> wss.
        let result = engine_url_to_ws_url("https://localhost:8551").unwrap();
        assert_eq!(result.as_str(), "wss://localhost:8546/");

        // Any incoming port is replaced.
        let result = engine_url_to_ws_url("http://localhost:9551").unwrap();
        assert_eq!(result.port(), Some(8546));

        // Already-WebSocket URLs keep their scheme.
        let result = engine_url_to_ws_url("ws://localhost:8546").unwrap();
        assert_eq!(result.scheme(), "ws");

        // Unsupported scheme and unparsable input are rejected.
        assert!(engine_url_to_ws_url("ftp://localhost:8551").is_err());
        assert!(engine_url_to_ws_url("not a valid url").is_err());
    }

    #[test]
    fn test_engine_url_to_ws_url_keeps_host_and_path() {
        let result = engine_url_to_ws_url("https://example.com:8551/rpc").unwrap();
        assert_eq!(result.as_str(), "wss://example.com:8546/rpc");
    }

    #[test]
    fn test_derive_ws_rpc_url() {
        // An explicit URL wins and is used as parsed, including its own port.
        let explicit =
            derive_ws_rpc_url(Some("ws://example.com:9000"), "http://localhost:8551").unwrap();
        assert_eq!(explicit.as_str(), "ws://example.com:9000/");

        // Without one, the URL is derived from the engine URL.
        let derived = derive_ws_rpc_url(None, "http://localhost:8551").unwrap();
        assert_eq!(derived.as_str(), "ws://localhost:8546/");

        // Errors from either path are surfaced.
        assert!(derive_ws_rpc_url(Some("not a valid url"), "http://localhost:8551").is_err());
        assert!(derive_ws_rpc_url(None, "ftp://localhost:8551").is_err());
    }

    #[tokio::test]
    async fn test_waiter_with_duration() {
        let mut waiter = PersistenceWaiter::with_duration(Duration::from_millis(1));

        let start = Instant::now();
        waiter.on_block(1).await.unwrap();
        waiter.on_block(2).await.unwrap();
        waiter.on_block(3).await.unwrap();

        assert!(start.elapsed() >= Duration::from_millis(3));
        // Without a subscription, blocks are not counted toward persistence checkpoints.
        assert_eq!(waiter.blocks_sent, 0);
    }
}