reth_rpc_layer/
auth_layer.rs

1use super::AuthValidator;
2use jsonrpsee_http_client::{HttpRequest, HttpResponse};
3use pin_project::pin_project;
4use std::{
5    future::Future,
6    pin::Pin,
7    task::{Context, Poll},
8};
9use tower::{Layer, Service};
10
/// This is an Http middleware layer that acts as an
/// interceptor for `Authorization` headers. Incoming requests are dispatched to
/// an inner [`AuthValidator`]. Invalid requests are blocked and the validator's error response is
/// returned. Valid requests are instead dispatched to the next layer along the chain.
///
/// # How to integrate
/// ```rust
/// async fn build_layered_rpc_server() {
///     use jsonrpsee::server::ServerBuilder;
///     use reth_rpc_layer::{AuthLayer, JwtAuthValidator, JwtSecret};
///     use std::net::SocketAddr;
///
///     const AUTH_PORT: u32 = 8551;
///     const AUTH_ADDR: &str = "0.0.0.0";
///     const AUTH_SECRET: &str =
///         "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430";
///
///     let addr = format!("{AUTH_ADDR}:{AUTH_PORT}");
///     let secret = JwtSecret::from_hex(AUTH_SECRET).unwrap();
///     let validator = JwtAuthValidator::new(secret);
///     let layer = AuthLayer::new(validator);
///     let middleware = tower::ServiceBuilder::default().layer(layer);
///
///     let _server = ServerBuilder::default()
///         .set_http_middleware(middleware)
///         .build(addr.parse::<SocketAddr>().unwrap())
///         .await
///         .unwrap();
/// }
/// ```
// `Debug`/`Clone` are derived (and thus only implemented when `V` provides them) instead of
// silencing `missing_debug_implementations`, so the layer can be logged and reused.
#[derive(Clone, Debug)]
pub struct AuthLayer<V> {
    /// Validator that is cloned into every [`AuthService`] this layer builds.
    validator: V,
}
45
46impl<V> AuthLayer<V> {
47    /// Creates an instance of [`AuthLayer`].
48    /// `validator` is a generic trait able to validate requests (see [`AuthValidator`]).
49    pub const fn new(validator: V) -> Self {
50        Self { validator }
51    }
52}
53
54impl<S, V> Layer<S> for AuthLayer<V>
55where
56    V: Clone,
57{
58    type Service = AuthService<S, V>;
59
60    fn layer(&self, inner: S) -> Self::Service {
61        AuthService { validator: self.validator.clone(), inner }
62    }
63}
64
65/// This type is the actual implementation of the middleware. It follows the [`Service`]
66/// specification to correctly proxy Http requests to its inner service after headers validation.
67#[derive(Clone, Debug)]
68pub struct AuthService<S, V> {
69    /// Performs auth validation logics
70    validator: V,
71    /// Recipient of authorized Http requests
72    inner: S,
73}
74
75impl<S, V> Service<HttpRequest> for AuthService<S, V>
76where
77    S: Service<HttpRequest, Response = HttpResponse>,
78    V: AuthValidator,
79    Self: Clone,
80{
81    type Response = HttpResponse;
82    type Error = S::Error;
83    type Future = ResponseFuture<S::Future>;
84
85    /// If we get polled it means that we dispatched an authorized Http request to the inner layer.
86    /// So we just poll the inner layer ourselves.
87    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88        self.inner.poll_ready(cx)
89    }
90
91    /// This is the entrypoint of the service. We receive an Http request and check the validity of
92    /// the authorization header.
93    ///
94    /// Returns a future that wraps either:
95    /// - The inner service future for authorized requests
96    /// - An error Http response in case of authorization errors
97    fn call(&mut self, req: HttpRequest) -> Self::Future {
98        match self.validator.validate(req.headers()) {
99            Ok(_) => ResponseFuture::future(self.inner.call(req)),
100            Err(res) => ResponseFuture::invalid_auth(res),
101        }
102    }
103}
104
105/// A future representing the response of an RPC request
106#[pin_project]
107#[allow(missing_debug_implementations)]
108pub struct ResponseFuture<F> {
109    /// The kind of response future, error or pending
110    #[pin]
111    kind: Kind<F>,
112}
113
114impl<F> ResponseFuture<F> {
115    const fn future(future: F) -> Self {
116        Self { kind: Kind::Future { future } }
117    }
118
119    const fn invalid_auth(err_res: HttpResponse) -> Self {
120        Self { kind: Kind::Error { response: Some(err_res) } }
121    }
122}
123
124#[pin_project(project = KindProj)]
125enum Kind<F> {
126    Future {
127        #[pin]
128        future: F,
129    },
130    Error {
131        response: Option<HttpResponse>,
132    },
133}
134
135impl<F, E> Future for ResponseFuture<F>
136where
137    F: Future<Output = Result<HttpResponse, E>>,
138{
139    type Output = F::Output;
140
141    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142        match self.project().kind.project() {
143            KindProj::Future { future } => future.poll(cx),
144            KindProj::Error { response } => {
145                let response = response.take().unwrap();
146                Poll::Ready(Ok(response))
147            }
148        }
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JwtAuthValidator;
    use alloy_rpc_types_engine::{Claims, JwtError, JwtSecret};
    use jsonrpsee::{
        server::{RandomStringIdProvider, ServerBuilder, ServerHandle},
        RpcModule,
    };
    use reqwest::{header, StatusCode};
    use std::{
        net::SocketAddr,
        time::{Duration, SystemTime, UNIX_EPOCH},
    };

    /// Loopback address with port `0`: the OS assigns a free ephemeral port to every spawned
    /// server. This keeps the tests off the well-known Engine API port 8551 (where a local node
    /// may already be listening), lets each scenario run as its own concurrent test, and avoids
    /// connecting to `0.0.0.0`, which is not a portable client destination.
    const AUTH_ADDR: &str = "127.0.0.1:0";
    /// Secret the test server validates tokens against.
    const SECRET: &str = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430";

    #[tokio::test]
    async fn valid_jwt() {
        let claims = Claims { iat: to_u64(SystemTime::now()), exp: Some(10000000000) };
        let secret = JwtSecret::from_hex(SECRET).unwrap(); // Same secret as the server
        let jwt = secret.encode(&claims).unwrap();
        let (status, _) = send_request(Some(jwt)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_jwt_error() {
        let (status, body) = send_request(None).await;
        let expected = JwtError::MissingOrInvalidAuthorizationHeader;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    #[tokio::test]
    async fn wrong_jwt_signature_error() {
        // This secret is different from the server's, so the token gets a different signature.
        let secret = JwtSecret::random();
        let claims = Claims { iat: to_u64(SystemTime::now()), exp: Some(10000000000) };
        let jwt = secret.encode(&claims).unwrap();

        let (status, body) = send_request(Some(jwt)).await;
        let expected = JwtError::InvalidSignature;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    #[tokio::test]
    async fn invalid_issuance_timestamp_error() {
        let secret = JwtSecret::from_hex(SECRET).unwrap(); // Same secret as the server

        let iat = to_u64(SystemTime::now()) + 1000;
        let claims = Claims { iat, exp: Some(10000000000) };
        let jwt = secret.encode(&claims).unwrap();

        let (status, body) = send_request(Some(jwt)).await;
        let expected = JwtError::InvalidIssuanceTimestamp;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    #[tokio::test]
    async fn jwt_decode_error() {
        let jwt = "this jwt has serious encoding problems".to_string();
        let (status, body) = send_request(Some(jwt)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, "JWT decoding error: InvalidToken".to_string());
    }

    /// Spawns a dedicated server, sends one `greet_melkor` request to it and shuts it down.
    ///
    /// When `jwt` is `None` no `Authorization` header is sent at all, so the "missing header"
    /// path is exercised rather than an empty bearer token.
    async fn send_request(jwt: Option<String>) -> (StatusCode, String) {
        let (server, addr) = spawn_server().await;
        let client = reqwest::Client::builder().timeout(Duration::from_secs(1)).build().unwrap();

        let body = r#"{"jsonrpc": "2.0", "method": "greet_melkor", "params": [], "id": 1}"#;
        let mut request = client
            .post(format!("http://{addr}"))
            .body(body)
            .header(header::CONTENT_TYPE, "application/json");
        if let Some(jwt) = jwt {
            request = request.bearer_auth(jwt);
        }
        let response = request.send().await.unwrap();
        let status = response.status();
        let body = response.text().await.unwrap();

        server.stop().unwrap();
        server.stopped().await;

        (status, body)
    }

    /// Spawns a new RPC server equipped with an [`AuthLayer`] JWT middleware.
    ///
    /// Returns the server handle together with the address it actually bound to.
    async fn spawn_server() -> (ServerHandle, SocketAddr) {
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        let validator = JwtAuthValidator::new(secret);
        let layer = AuthLayer::new(validator);
        let middleware = tower::ServiceBuilder::default().layer(layer);

        // Create a layered server
        let server = ServerBuilder::default()
            .set_id_provider(RandomStringIdProvider::new(16))
            .set_http_middleware(middleware)
            .build(AUTH_ADDR.parse::<SocketAddr>().unwrap())
            .await
            .unwrap();
        // Resolve the OS-assigned port before `start` consumes the server.
        let addr = server.local_addr().unwrap();

        // Create a mock rpc module
        let mut module = RpcModule::new(());
        module.register_method("greet_melkor", |_, _, _| "You are the dark lord").unwrap();

        (server.start(module), addr)
    }

    fn to_u64(time: SystemTime) -> u64 {
        time.duration_since(UNIX_EPOCH).unwrap().as_secs()
    }
}