reth_bench/
authenticated_transport.rs

1//! This contains an authenticated rpc transport that can be used to send engine API newPayload
2//! requests.
3
4use std::sync::Arc;
5
6use alloy_json_rpc::{RequestPacket, ResponsePacket};
7use alloy_pubsub::{PubSubConnect, PubSubFrontend};
8use alloy_rpc_types_engine::{Claims, JwtSecret};
9use alloy_transport::{
10    utils::guess_local_url, Authorization, BoxTransport, TransportConnect, TransportError,
11    TransportErrorKind, TransportFut,
12};
13use alloy_transport_http::{reqwest::Url, Http, ReqwestTransport};
14use alloy_transport_ipc::IpcConnect;
15use alloy_transport_ws::WsConnect;
16use futures::FutureExt;
17use reqwest::header::HeaderValue;
18use std::task::{Context, Poll};
19use tokio::sync::RwLock;
20use tower::Service;
21
22/// An enum representing the different transports that can be used to connect to a runtime.
23/// Only meant to be used internally by [`AuthenticatedTransport`].
24#[derive(Clone, Debug)]
25pub enum InnerTransport {
26    /// HTTP transport
27    Http(ReqwestTransport),
28    /// `WebSocket` transport
29    Ws(PubSubFrontend),
30    /// IPC transport
31    Ipc(PubSubFrontend),
32}
33
34impl InnerTransport {
35    /// Connects to a transport based on the given URL and JWT. Returns an [`InnerTransport`] and
36    /// the [`Claims`] generated from the jwt.
37    async fn connect(
38        url: Url,
39        jwt: JwtSecret,
40    ) -> Result<(Self, Claims), AuthenticatedTransportError> {
41        match url.scheme() {
42            "http" | "https" => Self::connect_http(url, jwt),
43            "ws" | "wss" => Self::connect_ws(url, jwt).await,
44            "file" => Ok((Self::connect_ipc(url).await?, Claims::default())),
45            _ => Err(AuthenticatedTransportError::BadScheme(url.scheme().to_string())),
46        }
47    }
48
49    /// Connects to an HTTP [`alloy_transport_http::Http`] transport. Returns an [`InnerTransport`]
50    /// and the [Claims] generated from the jwt.
51    fn connect_http(
52        url: Url,
53        jwt: JwtSecret,
54    ) -> Result<(Self, Claims), AuthenticatedTransportError> {
55        let mut client_builder =
56            reqwest::Client::builder().tls_built_in_root_certs(url.scheme() == "https");
57        let mut headers = reqwest::header::HeaderMap::new();
58
59        // Add the JWT to the headers if we can decode it.
60        let (auth, claims) =
61            build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
62
63        let mut auth_value: HeaderValue =
64            HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
65        auth_value.set_sensitive(true);
66
67        headers.insert(reqwest::header::AUTHORIZATION, auth_value);
68        client_builder = client_builder.default_headers(headers);
69
70        let client =
71            client_builder.build().map_err(AuthenticatedTransportError::HttpConstructionError)?;
72
73        let inner = Self::Http(Http::with_client(client, url));
74        Ok((inner, claims))
75    }
76
77    /// Connects to a `WebSocket` [`alloy_transport_ws::WsConnect`] transport. Returns an
78    /// [`InnerTransport`] and the [`Claims`] generated from the jwt.
79    async fn connect_ws(
80        url: Url,
81        jwt: JwtSecret,
82    ) -> Result<(Self, Claims), AuthenticatedTransportError> {
83        // Add the JWT to the headers if we can decode it.
84        let (auth, claims) =
85            build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
86
87        let inner = WsConnect::new(url.clone())
88            .with_auth(auth)
89            .into_service()
90            .await
91            .map(Self::Ws)
92            .map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))?;
93
94        Ok((inner, claims))
95    }
96
97    /// Connects to an IPC [`alloy_transport_ipc::IpcConnect`] transport. Returns an
98    /// [`InnerTransport`]. Does not return any [`Claims`] because IPC does not require them.
99    async fn connect_ipc(url: Url) -> Result<Self, AuthenticatedTransportError> {
100        // IPC, even for engine, typically does not require auth because it's local
101        IpcConnect::new(url.to_string())
102            .into_service()
103            .await
104            .map(InnerTransport::Ipc)
105            .map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))
106    }
107}
108
109/// An authenticated transport that can be used to send requests that contain a jwt bearer token.
110#[derive(Debug, Clone)]
111pub struct AuthenticatedTransport {
112    /// The inner actual transport used.
113    ///
114    /// Also contains the current claims being used. This is used to determine whether or not we
115    /// should create another client.
116    inner_and_claims: Arc<RwLock<(InnerTransport, Claims)>>,
117    /// The current jwt is being used. This is so we can recreate claims.
118    jwt: JwtSecret,
119    /// The current URL is being used. This is so we can recreate the client if needed.
120    url: Url,
121}
122
123/// An error that can occur when creating an authenticated transport.
124#[derive(Debug, thiserror::Error)]
125pub enum AuthenticatedTransportError {
126    /// The URL is invalid.
127    #[error("The URL is invalid")]
128    InvalidUrl,
129    /// Failed to lock transport
130    #[error("Failed to lock transport")]
131    LockFailed,
132    /// The JWT is invalid.
133    #[error("The JWT is invalid: {0}")]
134    InvalidJwt(String),
135    /// The transport failed to connect.
136    #[error("The transport failed to connect to {1}, transport error: {0}")]
137    TransportError(TransportError, String),
138    /// The http client could not be built.
139    #[error("The http client could not be built")]
140    HttpConstructionError(reqwest::Error),
141    /// The scheme is invalid.
142    #[error("The URL scheme is invalid: {0}")]
143    BadScheme(String),
144}
145
146impl AuthenticatedTransport {
147    /// Create a new builder with the given URL.
148    pub async fn connect(url: Url, jwt: JwtSecret) -> Result<Self, AuthenticatedTransportError> {
149        let (inner, claims) = InnerTransport::connect(url.clone(), jwt).await?;
150        Ok(Self { inner_and_claims: Arc::new(RwLock::new((inner, claims))), jwt, url })
151    }
152
153    /// Sends a request using the underlying transport.
154    ///
155    /// For sending the actual request, this action is delegated down to the underlying transport
156    /// through Tower's [`tower::Service::call`]. See tower's [`tower::Service`] trait for more
157    /// information.
158    fn request(&self, req: RequestPacket) -> TransportFut<'static> {
159        let this = self.clone();
160
161        Box::pin(async move {
162            let mut inner_and_claims = this.inner_and_claims.write().await;
163
164            // shift the iat forward by one second so there is some buffer time
165            let mut shifted_claims = inner_and_claims.1;
166            shifted_claims.iat -= 1;
167
168            // if the claims are out of date, reset the inner transport
169            if !shifted_claims.is_within_time_window() {
170                let (new_inner, new_claims) =
171                    InnerTransport::connect(this.url.clone(), this.jwt).await.map_err(|e| {
172                        TransportError::Transport(TransportErrorKind::Custom(Box::new(e)))
173                    })?;
174                *inner_and_claims = (new_inner, new_claims);
175            }
176
177            match inner_and_claims.0 {
178                InnerTransport::Http(ref mut http) => http.call(req),
179                InnerTransport::Ws(ref mut ws) => ws.call(req),
180                InnerTransport::Ipc(ref mut ipc) => ipc.call(req),
181            }
182            .await
183        })
184    }
185}
186
187fn build_auth(secret: JwtSecret) -> eyre::Result<(Authorization, Claims)> {
188    // Generate claims (iat with current timestamp), this happens by default using the Default trait
189    // for Claims.
190    let claims = Claims::default();
191    let token = secret.encode(&claims)?;
192    let auth = Authorization::Bearer(token);
193
194    Ok((auth, claims))
195}
196
197/// This specifies how to connect to an authenticated transport.
198#[derive(Clone, Debug)]
199pub struct AuthenticatedTransportConnect {
200    /// The URL to connect to.
201    url: Url,
202    /// The JWT secret is used to authenticate the transport.
203    jwt: JwtSecret,
204}
205
206impl AuthenticatedTransportConnect {
207    /// Create a new builder with the given URL.
208    pub const fn new(url: Url, jwt: JwtSecret) -> Self {
209        Self { url, jwt }
210    }
211}
212
213impl TransportConnect for AuthenticatedTransportConnect {
214    fn is_local(&self) -> bool {
215        guess_local_url(&self.url)
216    }
217
218    async fn get_transport(&self) -> Result<BoxTransport, TransportError> {
219        Ok(BoxTransport::new(
220            AuthenticatedTransport::connect(self.url.clone(), self.jwt)
221                .map(|res| match res {
222                    Ok(transport) => Ok(transport),
223                    Err(err) => {
224                        Err(TransportError::Transport(TransportErrorKind::Custom(Box::new(err))))
225                    }
226                })
227                .await?,
228        ))
229    }
230}
231
232impl tower::Service<RequestPacket> for AuthenticatedTransport {
233    type Response = ResponsePacket;
234    type Error = TransportError;
235    type Future = TransportFut<'static>;
236
237    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
238        Poll::Ready(Ok(()))
239    }
240
241    fn call(&mut self, req: RequestPacket) -> Self::Future {
242        self.request(req)
243    }
244}
245
246impl tower::Service<RequestPacket> for &AuthenticatedTransport {
247    type Response = ResponsePacket;
248    type Error = TransportError;
249    type Future = TransportFut<'static>;
250
251    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
252        Poll::Ready(Ok(()))
253    }
254
255    fn call(&mut self, req: RequestPacket) -> Self::Future {
256        self.request(req)
257    }
258}