1use std::sync::Arc;
5
6use alloy_json_rpc::{RequestPacket, ResponsePacket};
7use alloy_pubsub::{PubSubConnect, PubSubFrontend};
8use alloy_rpc_types_engine::{Claims, JwtSecret};
9use alloy_transport::{
10 utils::guess_local_url, Authorization, BoxTransport, TransportConnect, TransportError,
11 TransportErrorKind, TransportFut,
12};
13use alloy_transport_http::{reqwest::Url, Http, ReqwestTransport};
14use alloy_transport_ipc::IpcConnect;
15use alloy_transport_ws::WsConnect;
16use futures::FutureExt;
17use reqwest::header::HeaderValue;
18use std::task::{Context, Poll};
19use tokio::sync::RwLock;
20use tower::Service;
21
22#[derive(Clone, Debug)]
25pub enum InnerTransport {
26 Http(ReqwestTransport),
28 Ws(PubSubFrontend),
30 Ipc(PubSubFrontend),
32}
33
34impl InnerTransport {
35 async fn connect(
38 url: Url,
39 jwt: JwtSecret,
40 ) -> Result<(Self, Claims), AuthenticatedTransportError> {
41 match url.scheme() {
42 "http" | "https" => Self::connect_http(url, jwt),
43 "ws" | "wss" => Self::connect_ws(url, jwt).await,
44 "file" => Ok((Self::connect_ipc(url).await?, Claims::default())),
45 _ => Err(AuthenticatedTransportError::BadScheme(url.scheme().to_string())),
46 }
47 }
48
49 fn connect_http(
52 url: Url,
53 jwt: JwtSecret,
54 ) -> Result<(Self, Claims), AuthenticatedTransportError> {
55 let mut client_builder = reqwest::Client::builder();
56 let mut headers = reqwest::header::HeaderMap::new();
57
58 let (auth, claims) =
60 build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
61
62 let mut auth_value: HeaderValue =
63 HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
64 auth_value.set_sensitive(true);
65
66 headers.insert(reqwest::header::AUTHORIZATION, auth_value);
67 client_builder = client_builder.default_headers(headers);
68
69 let client =
70 client_builder.build().map_err(AuthenticatedTransportError::HttpConstructionError)?;
71
72 let inner = Self::Http(Http::with_client(client, url));
73 Ok((inner, claims))
74 }
75
76 async fn connect_ws(
79 url: Url,
80 jwt: JwtSecret,
81 ) -> Result<(Self, Claims), AuthenticatedTransportError> {
82 let (auth, claims) =
84 build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
85
86 let inner = WsConnect::new(url.clone())
87 .with_auth(auth)
88 .into_service()
89 .await
90 .map(Self::Ws)
91 .map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))?;
92
93 Ok((inner, claims))
94 }
95
96 async fn connect_ipc(url: Url) -> Result<Self, AuthenticatedTransportError> {
99 IpcConnect::new(url.to_string())
101 .into_service()
102 .await
103 .map(InnerTransport::Ipc)
104 .map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))
105 }
106}
107
108#[derive(Debug, Clone)]
110pub struct AuthenticatedTransport {
111 inner_and_claims: Arc<RwLock<(InnerTransport, Claims)>>,
116 jwt: JwtSecret,
118 url: Url,
120}
121
122#[derive(Debug, thiserror::Error)]
124pub enum AuthenticatedTransportError {
125 #[error("The URL is invalid")]
127 InvalidUrl,
128 #[error("Failed to lock transport")]
130 LockFailed,
131 #[error("The JWT is invalid: {0}")]
133 InvalidJwt(String),
134 #[error("The transport failed to connect to {1}, transport error: {0}")]
136 TransportError(TransportError, String),
137 #[error("The http client could not be built")]
139 HttpConstructionError(reqwest::Error),
140 #[error("The URL scheme is invalid: {0}")]
142 BadScheme(String),
143}
144
145impl AuthenticatedTransport {
146 pub async fn connect(url: Url, jwt: JwtSecret) -> Result<Self, AuthenticatedTransportError> {
148 let (inner, claims) = InnerTransport::connect(url.clone(), jwt).await?;
149 Ok(Self { inner_and_claims: Arc::new(RwLock::new((inner, claims))), jwt, url })
150 }
151
152 fn request(&self, req: RequestPacket) -> TransportFut<'static> {
158 let this = self.clone();
159
160 Box::pin(async move {
161 let mut inner_and_claims = this.inner_and_claims.write().await;
162
163 let mut shifted_claims = inner_and_claims.1;
165 shifted_claims.iat -= 30;
166
167 if !shifted_claims.is_within_time_window() {
169 let (new_inner, new_claims) =
170 InnerTransport::connect(this.url.clone(), this.jwt).await.map_err(|e| {
171 TransportError::Transport(TransportErrorKind::Custom(Box::new(e)))
172 })?;
173 *inner_and_claims = (new_inner, new_claims);
174 }
175
176 match inner_and_claims.0 {
177 InnerTransport::Http(ref mut http) => http.call(req),
178 InnerTransport::Ws(ref mut ws) => ws.call(req),
179 InnerTransport::Ipc(ref mut ipc) => ipc.call(req),
180 }
181 .await
182 })
183 }
184}
185
186fn build_auth(secret: JwtSecret) -> eyre::Result<(Authorization, Claims)> {
187 let claims = Claims::default();
190 let token = secret.encode(&claims)?;
191 let auth = Authorization::Bearer(token);
192
193 Ok((auth, claims))
194}
195
196#[derive(Clone, Debug)]
198pub struct AuthenticatedTransportConnect {
199 url: Url,
201 jwt: JwtSecret,
203}
204
205impl AuthenticatedTransportConnect {
206 pub const fn new(url: Url, jwt: JwtSecret) -> Self {
208 Self { url, jwt }
209 }
210}
211
212impl TransportConnect for AuthenticatedTransportConnect {
213 fn is_local(&self) -> bool {
214 guess_local_url(&self.url)
215 }
216
217 async fn get_transport(&self) -> Result<BoxTransport, TransportError> {
218 Ok(BoxTransport::new(
219 AuthenticatedTransport::connect(self.url.clone(), self.jwt)
220 .map(|res| match res {
221 Ok(transport) => Ok(transport),
222 Err(err) => {
223 Err(TransportError::Transport(TransportErrorKind::Custom(Box::new(err))))
224 }
225 })
226 .await?,
227 ))
228 }
229}
230
231impl tower::Service<RequestPacket> for AuthenticatedTransport {
232 type Response = ResponsePacket;
233 type Error = TransportError;
234 type Future = TransportFut<'static>;
235
236 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
237 Poll::Ready(Ok(()))
238 }
239
240 fn call(&mut self, req: RequestPacket) -> Self::Future {
241 self.request(req)
242 }
243}
244
245impl tower::Service<RequestPacket> for &AuthenticatedTransport {
246 type Response = ResponsePacket;
247 type Error = TransportError;
248 type Future = TransportFut<'static>;
249
250 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
251 Poll::Ready(Ok(()))
252 }
253
254 fn call(&mut self, req: RequestPacket) -> Self::Future {
255 self.request(req)
256 }
257}