reth_rpc_layer/compression_layer.rs

1use jsonrpsee_http_client::{HttpBody, HttpRequest, HttpResponse};
2use std::{
3    future::Future,
4    pin::Pin,
5    task::{Context, Poll},
6};
7use tower::{Layer, Service};
8use tower_http::compression::{Compression, CompressionLayer as TowerCompressionLayer};
9
10/// This layer is a wrapper around [`tower_http::compression::CompressionLayer`] that integrates
11/// with jsonrpsee's HTTP types. It automatically compresses responses based on the client's
12/// Accept-Encoding header.
13#[expect(missing_debug_implementations)]
14#[derive(Clone)]
15pub struct CompressionLayer {
16    inner_layer: TowerCompressionLayer,
17}
18
19impl CompressionLayer {
20    /// Creates a new compression layer with zstd, gzip, brotli and deflate enabled.
21    pub fn new() -> Self {
22        Self {
23            inner_layer: TowerCompressionLayer::new().gzip(true).br(true).deflate(true).zstd(true),
24        }
25    }
26}
27
28impl Default for CompressionLayer {
29    /// Creates a new compression layer with default settings.
30    /// See [`CompressionLayer::new`] for details.
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl<S> Layer<S> for CompressionLayer {
37    type Service = CompressionService<S>;
38
39    fn layer(&self, inner: S) -> Self::Service {
40        CompressionService { compression: self.inner_layer.layer(inner) }
41    }
42}
43
44/// Service that performs response compression.
45///
46/// Created by [`CompressionLayer`].
47#[expect(missing_debug_implementations)]
48#[derive(Clone)]
49pub struct CompressionService<S> {
50    compression: Compression<S>,
51}
52
53impl<S> Service<HttpRequest> for CompressionService<S>
54where
55    S: Service<HttpRequest, Response = HttpResponse>,
56    S::Future: Send + 'static,
57{
58    type Response = HttpResponse;
59    type Error = S::Error;
60    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
61
62    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
63        self.compression.poll_ready(cx)
64    }
65
66    fn call(&mut self, req: HttpRequest) -> Self::Future {
67        let fut = self.compression.call(req);
68
69        Box::pin(async move {
70            let resp = fut.await?;
71            let (parts, compressed_body) = resp.into_parts();
72            let http_body = HttpBody::new(compressed_body);
73
74            Ok(Self::Response::from_parts(parts, http_body))
75        })
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING};
    use http_body_util::BodyExt;
    use jsonrpsee_http_client::{HttpRequest, HttpResponse};
    use std::{
        convert::Infallible,
        future::{poll_fn, ready},
    };

    const TEST_DATA: &str = "compress test data ";
    const REPEAT_COUNT: usize = 1000;

    /// Length of the body produced by [`MockRequestService`] before any compression.
    const UNCOMPRESSED_LEN: usize = TEST_DATA.len() * REPEAT_COUNT;

    /// Inner service that always answers with a large, highly repetitive body.
    #[derive(Clone)]
    struct MockRequestService;

    impl Service<HttpRequest> for MockRequestService {
        type Response = HttpResponse;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: HttpRequest) -> Self::Future {
            let body = HttpBody::from(TEST_DATA.repeat(REPEAT_COUNT));
            let response = HttpResponse::builder().body(body).unwrap();
            ready(Ok(response))
        }
    }

    fn setup_compression_service(
    ) -> impl Service<HttpRequest, Response = HttpResponse, Error = Infallible> {
        CompressionLayer::new().layer(MockRequestService)
    }

    /// Sends `request` through a fresh compression service.
    ///
    /// It waits for readiness first, because the tower `Service` contract requires
    /// `poll_ready` to return `Ready(Ok(()))` before `call` is invoked.
    async fn call_service(request: HttpRequest) -> HttpResponse {
        let mut service = setup_compression_service();
        poll_fn(|cx| service.poll_ready(cx)).await.unwrap();
        service.call(request).await.unwrap()
    }

    async fn get_response_size(response: HttpResponse) -> usize {
        // Get the total size of the response body
        response.into_body().collect().await.unwrap().to_bytes().len()
    }

    /// Requests `encoding` via `Accept-Encoding` and asserts two things: the response is
    /// encoded with it, and the body is smaller than the uncompressed payload.
    async fn assert_compressed_with(encoding: &str) {
        let request = HttpRequest::builder()
            .header(ACCEPT_ENCODING, encoding)
            .body(HttpBody::empty())
            .unwrap();

        let response = call_service(request).await;

        assert_eq!(
            response.headers().get(CONTENT_ENCODING).unwrap(),
            encoding,
            "Response should be {encoding} encoded"
        );

        // Verify the response body is actually compressed (should be smaller than original)
        let compressed_size = get_response_size(response).await;
        assert!(
            compressed_size < UNCOMPRESSED_LEN,
            "Compressed size ({compressed_size}) should be smaller than original size ({UNCOMPRESSED_LEN})"
        );
    }

    #[tokio::test]
    async fn test_gzip_compression() {
        assert_compressed_with("gzip").await;
    }

    #[tokio::test]
    async fn test_brotli_compression() {
        assert_compressed_with("br").await;
    }

    #[tokio::test]
    async fn test_deflate_compression() {
        assert_compressed_with("deflate").await;
    }

    #[tokio::test]
    async fn test_zstd_compression() {
        assert_compressed_with("zstd").await;
    }

    #[tokio::test]
    async fn test_no_compression_when_not_requested() {
        let request = HttpRequest::builder().body(HttpBody::empty()).unwrap();

        let response = call_service(request).await;
        assert!(
            response.headers().get(CONTENT_ENCODING).is_none(),
            "Response should not be compressed when not requested"
        );

        // Without negotiation the body must pass through at its original size.
        let response_size = get_response_size(response).await;
        assert_eq!(
            response_size, UNCOMPRESSED_LEN,
            "Response size ({response_size}) should equal original size ({UNCOMPRESSED_LEN})"
        );
    }
}