reth_rpc_layer/
compression_layer.rs
1use jsonrpsee_http_client::{HttpBody, HttpRequest, HttpResponse};
2use std::{
3 future::Future,
4 pin::Pin,
5 task::{Context, Poll},
6};
7use tower::{Layer, Service};
8use tower_http::compression::{Compression, CompressionLayer as TowerCompressionLayer};
9
10#[expect(missing_debug_implementations)]
14#[derive(Clone)]
15pub struct CompressionLayer {
16 inner_layer: TowerCompressionLayer,
17}
18
19impl CompressionLayer {
20 pub fn new() -> Self {
22 Self {
23 inner_layer: TowerCompressionLayer::new().gzip(true).br(true).deflate(true).zstd(true),
24 }
25 }
26}
27
28impl Default for CompressionLayer {
29 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl<S> Layer<S> for CompressionLayer {
37 type Service = CompressionService<S>;
38
39 fn layer(&self, inner: S) -> Self::Service {
40 CompressionService { compression: self.inner_layer.layer(inner) }
41 }
42}
43
44#[expect(missing_debug_implementations)]
48#[derive(Clone)]
49pub struct CompressionService<S> {
50 compression: Compression<S>,
51}
52
53impl<S> Service<HttpRequest> for CompressionService<S>
54where
55 S: Service<HttpRequest, Response = HttpResponse>,
56 S::Future: Send + 'static,
57{
58 type Response = HttpResponse;
59 type Error = S::Error;
60 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
61
62 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63 self.compression.poll_ready(cx)
64 }
65
66 fn call(&mut self, req: HttpRequest) -> Self::Future {
67 let fut = self.compression.call(req);
68
69 Box::pin(async move {
70 let resp = fut.await?;
71 let (parts, compressed_body) = resp.into_parts();
72 let http_body = HttpBody::new(compressed_body);
73
74 Ok(Self::Response::from_parts(parts, http_body))
75 })
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING};
    use http_body_util::BodyExt;
    use jsonrpsee_http_client::{HttpRequest, HttpResponse};
    use std::{convert::Infallible, future::ready};

    const TEST_DATA: &str = "compress test data ";
    const REPEAT_COUNT: usize = 1000;
    /// Length in bytes of the uncompressed body produced by [`MockRequestService`].
    const UNCOMPRESSED_LEN: usize = TEST_DATA.len() * REPEAT_COUNT;

    /// Inner service that always answers with `TEST_DATA` repeated `REPEAT_COUNT` times.
    #[derive(Clone)]
    struct MockRequestService;

    impl Service<HttpRequest> for MockRequestService {
        type Response = HttpResponse;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: HttpRequest) -> Self::Future {
            let body = HttpBody::from(TEST_DATA.repeat(REPEAT_COUNT));
            let response = HttpResponse::builder().body(body).unwrap();
            ready(Ok(response))
        }
    }

    /// Builds the mock service wrapped in the default [`CompressionLayer`].
    fn setup_compression_service(
    ) -> impl Service<HttpRequest, Response = HttpResponse, Error = Infallible> {
        CompressionLayer::new().layer(MockRequestService)
    }

    /// Collects the response body and returns its length in bytes.
    async fn get_response_size(response: HttpResponse) -> usize {
        response.into_body().collect().await.unwrap().to_bytes().len()
    }

    /// Requests `encoding` via `Accept-Encoding` and asserts that the response is labelled
    /// with that encoding and is smaller than the uncompressed payload.
    async fn assert_compressed_with(encoding: &str) {
        let mut service = setup_compression_service();
        let request = HttpRequest::builder()
            .header(ACCEPT_ENCODING, encoding)
            .body(HttpBody::empty())
            .unwrap();

        let response = service.call(request).await.unwrap();

        assert_eq!(
            response.headers().get(CONTENT_ENCODING).unwrap(),
            encoding,
            "Response should be {encoding} encoded"
        );

        let compressed_size = get_response_size(response).await;
        let uncompressed_len = UNCOMPRESSED_LEN;
        assert!(
            compressed_size < uncompressed_len,
            "Compressed size ({compressed_size}) should be smaller than original size ({uncompressed_len})"
        );
    }

    #[tokio::test]
    async fn test_gzip_compression() {
        assert_compressed_with("gzip").await;
    }

    #[tokio::test]
    async fn test_brotli_compression() {
        assert_compressed_with("br").await;
    }

    #[tokio::test]
    async fn test_deflate_compression() {
        assert_compressed_with("deflate").await;
    }

    #[tokio::test]
    async fn test_zstd_compression() {
        assert_compressed_with("zstd").await;
    }

    #[tokio::test]
    async fn test_no_compression_when_not_requested() {
        let mut service = setup_compression_service();
        let request = HttpRequest::builder().body(HttpBody::empty()).unwrap();

        let response = service.call(request).await.unwrap();
        assert!(
            response.headers().get(CONTENT_ENCODING).is_none(),
            "Response should not be compressed when not requested"
        );

        let response_size = get_response_size(response).await;
        let uncompressed_len = UNCOMPRESSED_LEN;
        assert_eq!(
            response_size, uncompressed_len,
            "Response size ({response_size}) should equal original size ({uncompressed_len})"
        );
    }
}