reth_rpc_layer/
auth_layer.rs
1use super::AuthValidator;
2use jsonrpsee_http_client::{HttpRequest, HttpResponse};
3use pin_project::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{Context, Poll},
8};
9use tower::{Layer, Service};
10
/// A [`Layer`] that authenticates incoming HTTP requests with an [`AuthValidator`].
///
/// Every service produced by this layer passes each request's headers to the validator.
/// Requests that fail validation are answered with the validator's error response and are
/// never forwarded to the wrapped service.
#[derive(Clone, Debug)]
pub struct AuthLayer<V> {
    /// Validator cloned into every [`AuthService`] this layer creates.
    validator: V,
}

impl<V> AuthLayer<V> {
    /// Creates a new [`AuthLayer`] that authenticates requests with `validator`.
    pub const fn new(validator: V) -> Self {
        Self { validator }
    }
}
53
54impl<S, V> Layer<S> for AuthLayer<V>
55where
56 V: Clone,
57{
58 type Service = AuthService<S, V>;
59
60 fn layer(&self, inner: S) -> Self::Service {
61 AuthService { validator: self.validator.clone(), inner }
62 }
63}
64
/// The service produced by [`AuthLayer`].
///
/// It runs `validator` against the headers of every request before delegating to `inner`.
#[derive(Clone, Debug)]
pub struct AuthService<S, V> {
    /// Checks request headers. A rejection short-circuits the call with its error response.
    validator: V,
    /// The wrapped service, which only receives requests that passed validation.
    inner: S,
}
74
75impl<S, V> Service<HttpRequest> for AuthService<S, V>
76where
77 S: Service<HttpRequest, Response = HttpResponse>,
78 V: AuthValidator,
79 Self: Clone,
80{
81 type Response = HttpResponse;
82 type Error = S::Error;
83 type Future = ResponseFuture<S::Future>;
84
85 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88 self.inner.poll_ready(cx)
89 }
90
91 fn call(&mut self, req: HttpRequest) -> Self::Future {
98 match self.validator.validate(req.headers()) {
99 Ok(_) => ResponseFuture::future(self.inner.call(req)),
100 Err(res) => ResponseFuture::invalid_auth(res),
101 }
102 }
103}
104
105#[pin_project]
107#[allow(missing_debug_implementations)]
108pub struct ResponseFuture<F> {
109 #[pin]
111 kind: Kind<F>,
112}
113
114impl<F> ResponseFuture<F> {
115 const fn future(future: F) -> Self {
116 Self { kind: Kind::Future { future } }
117 }
118
119 const fn invalid_auth(err_res: HttpResponse) -> Self {
120 Self { kind: Kind::Error { response: Some(err_res) } }
121 }
122}
123
124#[pin_project(project = KindProj)]
125enum Kind<F> {
126 Future {
127 #[pin]
128 future: F,
129 },
130 Error {
131 response: Option<HttpResponse>,
132 },
133}
134
135impl<F, E> Future for ResponseFuture<F>
136where
137 F: Future<Output = Result<HttpResponse, E>>,
138{
139 type Output = F::Output;
140
141 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
142 match self.project().kind.project() {
143 KindProj::Future { future } => future.poll(cx),
144 KindProj::Error { response } => {
145 let response = response.take().unwrap();
146 Poll::Ready(Ok(response))
147 }
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use crate::JwtAuthValidator;
    use alloy_rpc_types_engine::{Claims, JwtError, JwtSecret};
    use jsonrpsee::{
        server::{RandomStringIdProvider, ServerBuilder, ServerHandle},
        RpcModule,
    };
    use reqwest::{header, StatusCode};
    use std::{
        net::{Ipv4Addr, SocketAddr},
        time::{Duration, SystemTime, UNIX_EPOCH},
    };

    /// Hex-encoded secret shared by the test server and the "valid" clients.
    const SECRET: &str = "f79ae8046bc11c9927afe911db7143c51a806c4a537cc08e0d37140b0192f430";

    /// Expiry (seconds since the Unix epoch) far enough in the future that tokens never expire
    /// during a test run.
    const FAR_FUTURE_EXP: u64 = 10_000_000_000;

    #[tokio::test]
    async fn test_jwt_layer() {
        valid_jwt().await;
        missing_jwt_error().await;
        wrong_jwt_signature_error().await;
        invalid_issuance_timestamp_error().await;
        jwt_decode_error().await
    }

    async fn valid_jwt() {
        let claims = Claims { iat: to_u64(SystemTime::now()), exp: Some(FAR_FUTURE_EXP) };
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        let jwt = secret.encode(&claims).unwrap();
        let (status, _) = send_request(Some(jwt)).await;
        assert_eq!(status, StatusCode::OK);
    }

    async fn missing_jwt_error() {
        let (status, body) = send_request(None).await;
        let expected = JwtError::MissingOrInvalidAuthorizationHeader;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    async fn wrong_jwt_signature_error() {
        // A freshly generated secret, so the signature cannot match the server's `SECRET`.
        let secret = JwtSecret::random();
        let claims = Claims { iat: to_u64(SystemTime::now()), exp: Some(FAR_FUTURE_EXP) };
        let jwt = secret.encode(&claims).unwrap();

        let (status, body) = send_request(Some(jwt)).await;
        let expected = JwtError::InvalidSignature;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    async fn invalid_issuance_timestamp_error() {
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        // Issued 1000 seconds in the future, which the validator must reject.
        let iat = to_u64(SystemTime::now()) + 1000;
        let claims = Claims { iat, exp: Some(FAR_FUTURE_EXP) };
        let jwt = secret.encode(&claims).unwrap();

        let (status, body) = send_request(Some(jwt)).await;
        let expected = JwtError::InvalidIssuanceTimestamp;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, expected.to_string());
    }

    async fn jwt_decode_error() {
        let jwt = "this jwt has serious encoding problems".to_string();
        let (status, body) = send_request(Some(jwt)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
        assert_eq!(body, "JWT decoding error: InvalidToken".to_string());
    }

    /// Spawns a fresh server, sends one JSON-RPC request, and shuts the server down.
    ///
    /// An `Authorization: Bearer <jwt>` header is attached only when `jwt` is `Some`.
    /// Returns the response status and body.
    async fn send_request(jwt: Option<String>) -> (StatusCode, String) {
        let (server, addr) = spawn_server().await;
        let client = reqwest::Client::builder().timeout(Duration::from_secs(1)).build().unwrap();

        let body = r#"{"jsonrpc": "2.0", "method": "greet_melkor", "params": [], "id": 1}"#;
        let mut request = client
            .post(format!("http://{addr}"))
            .body(body)
            .header(header::CONTENT_TYPE, "application/json");
        // Omit the header entirely for the "missing JWT" case instead of sending an empty
        // `Bearer ` value, so that case really exercises a missing header.
        if let Some(jwt) = jwt {
            request = request.bearer_auth(jwt);
        }
        let response = request.send().await.unwrap();
        let status = response.status();
        let body = response.text().await.unwrap();

        server.stop().unwrap();
        server.stopped().await;

        (status, body)
    }

    /// Starts a JWT-protected JSON-RPC server on an OS-assigned loopback port.
    ///
    /// Returns the server handle and the address the server is actually listening on.
    async fn spawn_server() -> (ServerHandle, SocketAddr) {
        let secret = JwtSecret::from_hex(SECRET).unwrap();
        let validator = JwtAuthValidator::new(secret);
        let layer = AuthLayer::new(validator);
        let middleware = tower::ServiceBuilder::default().layer(layer);

        // Port 0 lets the OS pick a free port. This avoids clashing with anything already bound
        // to a fixed port such as 8551, and with other tests running in parallel.
        let server = ServerBuilder::default()
            .set_id_provider(RandomStringIdProvider::new(16))
            .set_http_middleware(middleware)
            .build(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();

        let mut module = RpcModule::new(());
        module.register_method("greet_melkor", |_, _, _| "You are the dark lord").unwrap();

        (server.start(module), addr)
    }

    /// Seconds elapsed between the Unix epoch and `time`.
    fn to_u64(time: SystemTime) -> u64 {
        time.duration_since(UNIX_EPOCH).unwrap().as_secs()
    }
}