reth_rpc_layer/
auth_layer.rsuse super::AuthValidator;
use jsonrpsee_http_client::{HttpRequest, HttpResponse};
use pin_project::pin_project;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tower::{Layer, Service};
/// Tower [`Layer`] that wraps a service in an [`AuthService`], so that every request is checked
/// by the validator `V` before it reaches the inner service.
///
/// Derives `Clone` and `Debug` (bounded on `V`) to match [`AuthService`].
#[derive(Clone, Debug)]
pub struct AuthLayer<V> {
    /// Validator cloned into each service produced by this layer.
    validator: V,
}

impl<V> AuthLayer<V> {
    /// Creates a layer that authenticates requests with `validator`.
    pub const fn new(validator: V) -> Self {
        Self { validator }
    }
}
impl<S, V: Clone> Layer<S> for AuthLayer<V> {
    type Service = AuthService<S, V>;

    /// Wraps `inner` in an [`AuthService`] that owns its own copy of this layer's validator.
    fn layer(&self, inner: S) -> Self::Service {
        let validator = V::clone(&self.validator);
        AuthService { inner, validator }
    }
}
/// Tower service that runs every incoming HTTP request through an [`AuthValidator`] before
/// forwarding it to the wrapped service `S`.
///
/// A request that fails validation is never handed to `inner`; the validator's error response
/// is returned to the caller instead.
#[derive(Clone, Debug)]
pub struct AuthService<S, V> {
    /// Validator consulted for every request's headers.
    validator: V,
    /// Wrapped service that handles requests which passed validation.
    inner: S,
}

impl<S, V> Service<HttpRequest> for AuthService<S, V>
where
    S: Service<HttpRequest, Response = HttpResponse>,
    V: AuthValidator,
    Self: Clone,
{
    type Response = HttpResponse;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let inner = &mut self.inner;
        inner.poll_ready(cx)
    }

    /// Validates the request headers; on failure, short-circuits with the validator's response,
    /// otherwise forwards the request to the inner service.
    fn call(&mut self, req: HttpRequest) -> Self::Future {
        if let Err(rejection) = self.validator.validate(req.headers()) {
            // The request never reaches `inner`.
            return ResponseFuture::invalid_auth(rejection);
        }
        ResponseFuture::future(self.inner.call(req))
    }
}
#[pin_project]
#[allow(missing_debug_implementations)]
pub struct ResponseFuture<F> {
#[pin]
kind: Kind<F>,
}
impl<F> ResponseFuture<F> {
const fn future(future: F) -> Self {
Self { kind: Kind::Future { future } }
}
const fn invalid_auth(err_res: HttpResponse) -> Self {
Self { kind: Kind::Error { response: Some(err_res) } }
}
}
/// Internal state of a [`ResponseFuture`]; `KindProj` is the pin-projection used in `poll`.
#[pin_project(project = KindProj)]
enum Kind<F> {
    /// The request passed validation; the inner service's future is structurally pinned and
    /// polled directly.
    Future {
        #[pin]
        future: F,
    },
    /// The request failed validation; holds the rejection response until `poll` takes it.
    Error {
        response: Option<HttpResponse>,
    },
}
/// Resolves to the inner service's result for authorized requests. For rejected requests it
/// resolves on the first poll to `Ok(response)`, where `response` is the validator's rejection.
impl<F, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<HttpResponse, E>>,
{
    type Output = F::Output;

    /// # Panics
    ///
    /// Panics if polled again after resolving with a rejection response.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project().kind.project() {
            KindProj::Future { future } => future.poll(cx),
            KindProj::Error { response } => {
                // The rejection is handed out exactly once. Polling a completed future violates
                // the `Future::poll` contract, so state that invariant instead of a bare unwrap.
                let response = response
                    .take()
                    .expect("ResponseFuture polled after it already completed");
                Poll::Ready(Ok(response))
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JwtAuthValidator;
    use alloy_rpc_types_engine::{Claims, JwtError, JwtSecret};
    use jsonrpsee::{
        server::{RandomStringIdProvider, ServerBuilder, ServerHandle},
        RpcModule,
    };
    use reqwest::{header, StatusCode};
    use std::{
        net::{Ipv4Addr, SocketAddr},
        time::{Duration, SystemTime, UNIX_EPOCH},
    };

    /// Hex-encoded secret the test server validates against.
    const SECRET: &str = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430";

    /// Runs every auth scenario in sequence.
    ///
    /// Each scenario spawns its own server on an OS-assigned loopback port, so it cannot clash
    /// with a locally running node on the Engine API port or with concurrent test runs.
    #[tokio::test]
    async fn test_jwt_layer() {
        valid_jwt().await;
        missing_jwt_error().await;
        wrong_jwt_signature_error().await;
        invalid_issuance_timestamp_error().await;
        jwt_decode_error().await
    }

    /// A token signed with the server's secret and a current `iat` is accepted.
    async fn valid_jwt() {
        let claims = Claims { iat: to_u64(SystemTime::now()), exp: Some(10000000000) };
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        let jwt = secret.encode(&claims).unwrap();
        let (status, _) = send_request(Some(jwt)).await;
        assert_eq!(status, StatusCode::OK);
    }

    /// A request carrying no `Authorization` header at all is rejected.
    async fn missing_jwt_error() {
        let (status, body) = send_request(None).await;
        let expected = JwtError::MissingOrInvalidAuthorizationHeader;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    /// A token signed with a different secret is rejected.
    async fn wrong_jwt_signature_error() {
        let secret = JwtSecret::random();
        let claims = Claims { iat: to_u64(SystemTime::now()), exp: Some(10000000000) };
        let jwt = secret.encode(&claims).unwrap();
        let (status, body) = send_request(Some(jwt)).await;
        let expected = JwtError::InvalidSignature;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    /// A token whose `iat` lies well in the future is rejected.
    async fn invalid_issuance_timestamp_error() {
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        let iat = to_u64(SystemTime::now()) + 1000;
        let claims = Claims { iat, exp: Some(10000000000) };
        let jwt = secret.encode(&claims).unwrap();
        let (status, body) = send_request(Some(jwt)).await;
        let expected = JwtError::InvalidIssuanceTimestamp;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    /// A syntactically malformed token is rejected with a decoding error.
    async fn jwt_decode_error() {
        let jwt = "this jwt has serious encoding problems".to_string();
        let (status, body) = send_request(Some(jwt)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, "JWT decoding error: InvalidToken".to_string());
    }

    /// Spawns a fresh auth-protected server, POSTs one JSON-RPC call to it, shuts it down and
    /// returns the response status and body.
    ///
    /// When `jwt` is `None`, no `Authorization` header is sent.
    async fn send_request(jwt: Option<String>) -> (StatusCode, String) {
        let (server, addr) = spawn_server().await;
        let client = reqwest::Client::builder().timeout(Duration::from_secs(1)).build().unwrap();
        let body = r#"{"jsonrpc": "2.0", "method": "greet_melkor", "params": [], "id": 1}"#;
        let mut request = client
            .post(format!("http://{addr}"))
            .body(body)
            .header(header::CONTENT_TYPE, "application/json");
        // `bearer_auth("")` would still send an `Authorization: Bearer ` header; omit it entirely
        // so the "missing header" path is really exercised.
        if let Some(jwt) = jwt {
            request = request.bearer_auth(jwt);
        }
        let response = request.send().await.unwrap();
        let status = response.status();
        let body = response.text().await.unwrap();
        server.stop().unwrap();
        server.stopped().await;
        (status, body)
    }

    /// Starts a jsonrpsee server guarded by [`AuthLayer`] on an ephemeral loopback port.
    ///
    /// Returns the server handle together with the address it actually bound to.
    async fn spawn_server() -> (ServerHandle, SocketAddr) {
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        let validator = JwtAuthValidator::new(secret);
        let layer = AuthLayer::new(validator);
        let middleware = tower::ServiceBuilder::default().layer(layer);
        let server = ServerBuilder::default()
            .set_id_provider(RandomStringIdProvider::new(16))
            .set_http_middleware(middleware)
            .build(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .await
            .unwrap();
        // Port 0 lets the OS choose a free port; read it back so the client can reach it.
        let addr = server.local_addr().unwrap();
        let mut module = RpcModule::new(());
        module.register_method("greet_melkor", |_, _, _| "You are the dark lord").unwrap();
        (server.start(module), addr)
    }

    /// Seconds since the Unix epoch for `time`.
    fn to_u64(time: SystemTime) -> u64 {
        time.duration_since(UNIX_EPOCH).unwrap().as_secs()
    }
}