use std::sync::Arc;
use alloy_json_rpc::{RequestPacket, ResponsePacket};
use alloy_pubsub::{PubSubConnect, PubSubFrontend};
use alloy_rpc_types_engine::{Claims, JwtSecret};
use alloy_transport::{
utils::guess_local_url, Authorization, Pbf, TransportConnect, TransportError,
TransportErrorKind, TransportFut,
};
use alloy_transport_http::{reqwest::Url, Http, ReqwestTransport};
use alloy_transport_ipc::IpcConnect;
use alloy_transport_ws::WsConnect;
use futures::FutureExt;
use reqwest::header::HeaderValue;
use std::task::{Context, Poll};
use tokio::sync::RwLock;
use tower::Service;
/// The concrete connection used by an [`AuthenticatedTransport`]; which variant is built
/// depends on the URL scheme passed to `InnerTransport::connect`.
#[derive(Clone, Debug)]
pub enum InnerTransport {
/// HTTP(S) transport whose `reqwest` client carries the JWT `Authorization` header as a
/// default header.
Http(ReqwestTransport),
/// WebSocket transport, opened through `WsConnect::with_auth` with a signed JWT.
Ws(PubSubFrontend),
/// IPC transport for `file` URLs; it is connected without any JWT.
Ipc(PubSubFrontend),
}
impl InnerTransport {
async fn connect(
url: Url,
jwt: JwtSecret,
) -> Result<(Self, Claims), AuthenticatedTransportError> {
match url.scheme() {
"http" | "https" => Self::connect_http(url, jwt),
"ws" | "wss" => Self::connect_ws(url, jwt).await,
"file" => Ok((Self::connect_ipc(url).await?, Claims::default())),
_ => Err(AuthenticatedTransportError::BadScheme(url.scheme().to_string())),
}
}
fn connect_http(
url: Url,
jwt: JwtSecret,
) -> Result<(Self, Claims), AuthenticatedTransportError> {
let mut client_builder =
reqwest::Client::builder().tls_built_in_root_certs(url.scheme() == "https");
let mut headers = reqwest::header::HeaderMap::new();
let (auth, claims) =
build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
let mut auth_value: HeaderValue =
HeaderValue::from_str(&auth.to_string()).expect("Header should be valid string");
auth_value.set_sensitive(true);
headers.insert(reqwest::header::AUTHORIZATION, auth_value);
client_builder = client_builder.default_headers(headers);
let client =
client_builder.build().map_err(AuthenticatedTransportError::HttpConstructionError)?;
let inner = Self::Http(Http::with_client(client, url));
Ok((inner, claims))
}
async fn connect_ws(
url: Url,
jwt: JwtSecret,
) -> Result<(Self, Claims), AuthenticatedTransportError> {
let (auth, claims) =
build_auth(jwt).map_err(|e| AuthenticatedTransportError::InvalidJwt(e.to_string()))?;
let inner = WsConnect::new(url.clone())
.with_auth(auth)
.into_service()
.await
.map(Self::Ws)
.map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))?;
Ok((inner, claims))
}
async fn connect_ipc(url: Url) -> Result<Self, AuthenticatedTransportError> {
IpcConnect::new(url.to_string())
.into_service()
.await
.map(InnerTransport::Ipc)
.map_err(|e| AuthenticatedTransportError::TransportError(e, url.to_string()))
}
}
/// A JWT-authenticated RPC transport that re-establishes its connection once the claims it
/// was authenticated with fall outside their validity window.
///
/// Clones share the same underlying connection through the `Arc`.
#[derive(Debug, Clone)]
pub struct AuthenticatedTransport {
/// The current connection paired with the claims its JWT was signed with; shared between
/// clones and replaced in place on reconnect.
inner_and_claims: Arc<RwLock<(InnerTransport, Claims)>>,
/// Secret used to sign a new JWT whenever an HTTP(S)/WS(S) connection is (re)established.
jwt: JwtSecret,
/// Endpoint to (re)connect to; its scheme selects the transport kind.
url: Url,
}
#[derive(Debug, thiserror::Error)]
pub enum AuthenticatedTransportError {
#[error("The URL is invalid")]
InvalidUrl,
#[error("Failed to lock transport")]
LockFailed,
#[error("The JWT is invalid: {0}")]
InvalidJwt(String),
#[error("The transport failed to connect to {1}, transport error: {0}")]
TransportError(TransportError, String),
#[error("The http client could not be built")]
HttpConstructionError(reqwest::Error),
#[error("The URL scheme is invalid: {0}")]
BadScheme(String),
}
impl AuthenticatedTransport {
pub async fn connect(url: Url, jwt: JwtSecret) -> Result<Self, AuthenticatedTransportError> {
let (inner, claims) = InnerTransport::connect(url.clone(), jwt).await?;
Ok(Self { inner_and_claims: Arc::new(RwLock::new((inner, claims))), jwt, url })
}
fn request(&self, req: RequestPacket) -> TransportFut<'static> {
let this = self.clone();
Box::pin(async move {
let mut inner_and_claims = this.inner_and_claims.write().await;
let mut shifted_claims = inner_and_claims.1;
shifted_claims.iat -= 1;
if !shifted_claims.is_within_time_window() {
let (new_inner, new_claims) =
InnerTransport::connect(this.url.clone(), this.jwt).await.map_err(|e| {
TransportError::Transport(TransportErrorKind::Custom(Box::new(e)))
})?;
*inner_and_claims = (new_inner, new_claims);
}
match inner_and_claims.0 {
InnerTransport::Http(ref http) => {
let mut http = http;
http.call(req)
}
InnerTransport::Ws(ref ws) => {
let mut ws = ws;
ws.call(req)
}
InnerTransport::Ipc(ref ipc) => {
let mut ipc = ipc;
ipc.call(req)
}
}
.await
})
}
}
/// Signs a fresh set of default JWT claims with `secret`.
///
/// Returns a bearer [`Authorization`] carrying the encoded token, together with the exact
/// claims that were signed, so callers can later check whether the token is still fresh.
///
/// # Errors
///
/// Fails if encoding the token with `secret` fails.
fn build_auth(secret: JwtSecret) -> eyre::Result<(Authorization, Claims)> {
    let signed_claims = Claims::default();
    let encoded_token = secret.encode(&signed_claims)?;
    let bearer = Authorization::Bearer(encoded_token);
    Ok((bearer, signed_claims))
}
/// Parameters from which [`TransportConnect::get_transport`] opens an
/// [`AuthenticatedTransport`].
#[derive(Debug, Clone)]
pub struct AuthenticatedTransportConnect {
    /// Endpoint to connect to; its scheme selects the transport kind.
    url: Url,
    /// Secret used to sign the JWT for authenticated connections.
    jwt: JwtSecret,
}

impl AuthenticatedTransportConnect {
    /// Stores `url` and `jwt` for a later connection attempt; nothing is opened here.
    pub const fn new(url: Url, jwt: JwtSecret) -> Self {
        Self { jwt, url }
    }
}
impl TransportConnect for AuthenticatedTransportConnect {
    type Transport = AuthenticatedTransport;

    /// Heuristically reports whether the configured URL points at a local endpoint.
    fn is_local(&self) -> bool {
        guess_local_url(&self.url)
    }

    /// Opens a new [`AuthenticatedTransport`]. Connection failures are surfaced as a custom
    /// [`TransportErrorKind`] wrapping the [`AuthenticatedTransportError`].
    fn get_transport<'a: 'b, 'b>(&'a self) -> Pbf<'b, Self::Transport, TransportError> {
        let connecting = AuthenticatedTransport::connect(self.url.clone(), self.jwt);
        connecting
            .map(|outcome| {
                outcome.map_err(|failure| {
                    TransportError::Transport(TransportErrorKind::Custom(Box::new(failure)))
                })
            })
            .boxed()
    }
}
impl tower::Service<RequestPacket> for AuthenticatedTransport {
    type Response = ResponsePacket;
    type Error = TransportError;
    type Future = TransportFut<'static>;

    /// Always reports ready; any reconnect happens inside the future returned by `call`.
    fn poll_ready(&mut self, _ctx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Dispatches `packet` through the transport's shared request path.
    fn call(&mut self, packet: RequestPacket) -> Self::Future {
        Self::request(self, packet)
    }
}
impl tower::Service<RequestPacket> for &AuthenticatedTransport {
type Response = ResponsePacket;
type Error = TransportError;
type Future = TransportFut<'static>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: RequestPacket) -> Self::Future {
self.request(req)
}
}