1use std::sync::Arc;
5
6use alloy_json_rpc::{RequestPacket, ResponsePacket};
7use alloy_pubsub::{PubSubConnect, PubSubFrontend};
8use alloy_rpc_types_engine::{Claims, JwtSecret};
9use alloy_transport::{
10 utils::guess_local_url, Authorization, BoxTransport, TransportConnect, TransportError,
11 TransportErrorKind, TransportFut,
12};
13use alloy_transport_http::{reqwest::Url, Http, ReqwestTransport};
14use alloy_transport_ipc::IpcConnect;
15use alloy_transport_ws::WsConnect;
16use futures::FutureExt;
17use reqwest::header::HeaderValue;
18use std::task::{Context, Poll};
19use tokio::sync::RwLock;
20use tower::Service;
21
/// The concrete transport an [`AuthenticatedTransport`] forwards requests to.
///
/// There is one variant per supported URL scheme family: HTTP(S), WebSocket and IPC.
#[derive(Clone, Debug)]
pub enum InnerTransport {
    /// HTTP(S) transport. The JWT bearer token is installed as a default
    /// `Authorization` header on the underlying `reqwest` client.
    Http(ReqwestTransport),
    /// WebSocket transport. The JWT is supplied through `WsConnect::with_auth`
    /// when the connection is established.
    Ws(PubSubFrontend),
    /// IPC transport. No authentication is attached.
    Ipc(PubSubFrontend),
}
33
34impl InnerTransport {
35 async fn connect(
38 url: Url,
39 jwt: JwtSecret,
40 ) -> Result<(Self, Claims), AuthenticatedTransportError> {
41 match url.scheme() {
42 "http" | "https" => Self::connect_http(url, jwt),
43 "ws" | "wss" => Self::connect_ws(url, jwt).await,
44 "file" => Ok((Self::connect_ipc(url).await?, Claims::default())),
45 _ => Err(AuthenticatedTransportError::BadScheme(url.scheme().to_string())),
46 }
47 }
48
49 fn connect_http(
52 url: Url,
53 jwt: JwtSecret,
54 ) -> Result<(Self, Claims), AuthenticatedTransportError> {
55 let mut client_builder =
56 reqwest::Client::builder().tls_built_in_root_certs(url.scheme() == "https");
57 let mut headers = reqwest::header::HeaderMap::new();
58
59 let (auth, claims) =
61 build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
62
63 let mut auth_value: HeaderValue =
64 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
65 auth_value.set_sensitive(true);
66
67 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
68 client_builder = client_builder.default_headers(headers);
69
70 let client =
71 client_builder.build().map_err(AuthenticatedTransportError::HttpConstructionError)?;
72
73 let inner = Self::Http(Http::with_client(client, url));
74 Ok((inner, claims))
75 }
76
77 async fn connect_ws(
80 url: Url,
81 jwt: JwtSecret,
82 ) -> Result<(Self, Claims), AuthenticatedTransportError> {
83 let (auth, claims) =
85 build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
86
87 let inner = WsConnect::new(url.clone())
88 .with_auth(auth)
89 .into_service()
90 .await
91 .map(Self::Ws)
92 .map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))?;
93
94 Ok((inner, claims))
95 }
96
97 async fn connect_ipc(url: Url) -> Result<Self, AuthenticatedTransportError> {
100 IpcConnect::new(url.to_string())
102 .into_service()
103 .await
104 .map(InnerTransport::Ipc)
105 .map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))
106 }
107}
108
/// A JWT-authenticated RPC transport over HTTP(S), WebSocket or IPC.
///
/// Before each request it checks the claims of the current connection. When
/// they are about to leave their valid time window, it reconnects with a
/// freshly minted token.
#[derive(Debug, Clone)]
pub struct AuthenticatedTransport {
    /// The current inner transport and the claims its token was built from.
    /// Held behind an `Arc` so every clone observes a reconnect.
    inner_and_claims: Arc<RwLock<(InnerTransport, Claims)>>,
    /// Secret used to mint a new token on (re)connect.
    jwt: JwtSecret,
    /// Endpoint to (re)connect to.
    url: Url,
}
122
123#[derive(Debug, thiserror::Error)]
125pub enum AuthenticatedTransportError {
126 #[error("The URL is invalid")]
128 InvalidUrl,
129 #[error("Failed to lock transport")]
131 LockFailed,
132 #[error("The JWT is invalid: {0}")]
134 InvalidJwt(String),
135 #[error("The transport failed to connect to {1}, transport error: {0}")]
137 TransportError(TransportError, String),
138 #[error("The http client could not be built")]
140 HttpConstructionError(reqwest::Error),
141 #[error("The URL scheme is invalid: {0}")]
143 BadScheme(String),
144}
145
146impl AuthenticatedTransport {
147 pub async fn connect(url: Url, jwt: JwtSecret) -> Result<Self, AuthenticatedTransportError> {
149 let (inner, claims) = InnerTransport::connect(url.clone(), jwt).await?;
150 Ok(Self { inner_and_claims: Arc::new(RwLock::new((inner, claims))), jwt, url })
151 }
152
153 fn request(&self, req: RequestPacket) -> TransportFut<'static> {
159 let this = self.clone();
160
161 Box::pin(async move {
162 let mut inner_and_claims = this.inner_and_claims.write().await;
163
164 let mut shifted_claims = inner_and_claims.1;
166 shifted_claims.iat -= 1;
167
168 if !shifted_claims.is_within_time_window() {
170 let (new_inner, new_claims) =
171 InnerTransport::connect(this.url.clone(), this.jwt).await.map_err(|e| {
172 TransportError::Transport(TransportErrorKind::Custom(Box::new(e)))
173 })?;
174 *inner_and_claims = (new_inner, new_claims);
175 }
176
177 match inner_and_claims.0 {
178 InnerTransport::Http(ref mut http) => http.call(req),
179 InnerTransport::Ws(ref mut ws) => ws.call(req),
180 InnerTransport::Ipc(ref mut ipc) => ipc.call(req),
181 }
182 .await
183 })
184 }
185}
186
187fn build_auth(secret: JwtSecret) -> eyre::Result<(Authorization, Claims)> {
188 let claims = Claims::default();
191 let token = secret.encode(&claims)?;
192 let auth = Authorization::Bearer(token);
193
194 Ok((auth, claims))
195}
196
/// A [`TransportConnect`] implementation that produces an
/// [`AuthenticatedTransport`] for a fixed URL and JWT secret.
#[derive(Clone, Debug)]
pub struct AuthenticatedTransportConnect {
    /// Endpoint to connect to.
    url: Url,
    /// Secret used to mint the JWT for the connection.
    jwt: JwtSecret,
}
205
206impl AuthenticatedTransportConnect {
207 pub const fn new(url: Url, jwt: JwtSecret) -> Self {
209 Self { url, jwt }
210 }
211}
212
213impl TransportConnect for AuthenticatedTransportConnect {
214 fn is_local(&self) -> bool {
215 guess_local_url(&self.url)
216 }
217
218 async fn get_transport(&self) -> Result<BoxTransport, TransportError> {
219 Ok(BoxTransport::new(
220 AuthenticatedTransport::connect(self.url.clone(), self.jwt)
221 .map(|res| match res {
222 Ok(transport) => Ok(transport),
223 Err(err) => {
224 Err(TransportError::Transport(TransportErrorKind::Custom(Box::new(err))))
225 }
226 })
227 .await?,
228 ))
229 }
230}
231
232impl tower::Service<RequestPacket> for AuthenticatedTransport {
233 type Response = ResponsePacket;
234 type Error = TransportError;
235 type Future = TransportFut<'static>;
236
237 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
238 Poll::Ready(Ok(()))
239 }
240
241 fn call(&mut self, req: RequestPacket) -> Self::Future {
242 self.request(req)
243 }
244}
245
246impl tower::Service<RequestPacket> for &AuthenticatedTransport {
247 type Response = ResponsePacket;
248 type Error = TransportError;
249 type Future = TransportFut<'static>;
250
251 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
252 Poll::Ready(Ok(()))
253 }
254
255 fn call(&mut self, req: RequestPacket) -> Self::Future {
256 self.request(req)
257 }
258}