reth_rpc_layer/
compression_layer.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
use jsonrpsee_http_client::{HttpBody, HttpRequest, HttpResponse};
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower::{Layer, Service};
use tower_http::compression::{Compression, CompressionLayer as TowerCompressionLayer};

/// Tower layer that adds HTTP response compression to a jsonrpsee service.
///
/// Internally this wraps [`tower_http::compression::CompressionLayer`] and adapts its output to
/// jsonrpsee's HTTP types. The encoding used for a response is negotiated from the client's
/// `Accept-Encoding` request header.
#[derive(Clone)]
#[allow(missing_debug_implementations)] // this wrapper intentionally provides no `Debug` impl
pub struct CompressionLayer {
    /// The configured tower-http layer that performs the actual compression.
    inner_layer: TowerCompressionLayer,
}

impl CompressionLayer {
    /// Builds a compression layer with every supported encoding switched on:
    /// gzip, brotli (`br`), deflate and zstd.
    pub fn new() -> Self {
        let inner_layer =
            TowerCompressionLayer::new().gzip(true).br(true).deflate(true).zstd(true);
        Self { inner_layer }
    }
}

impl Default for CompressionLayer {
    /// Creates a new compression layer with default settings.
    /// See [`CompressionLayer::new`] for details.
    fn default() -> Self {
        Self::new()
    }
}

impl<S> Layer<S> for CompressionLayer {
    type Service = CompressionService<S>;

    /// Wraps `inner` with the configured tower-http compression service.
    fn layer(&self, inner: S) -> Self::Service {
        let compression = self.inner_layer.layer(inner);
        CompressionService { compression }
    }
}

/// [`Service`] that compresses responses produced by an inner service.
///
/// Instances are built by [`CompressionLayer`]. Each call goes through a tower-http
/// [`Compression`] service, and the compressed body is re-wrapped into jsonrpsee's [`HttpBody`].
#[derive(Clone)]
#[allow(missing_debug_implementations)] // this wrapper intentionally provides no `Debug` impl
pub struct CompressionService<S> {
    /// The tower-http compression service wrapping the user's inner service.
    compression: Compression<S>,
}

impl<S> Service<HttpRequest> for CompressionService<S>
where
    S: Service<HttpRequest, Response = HttpResponse>,
    S::Future: Send + 'static,
{
    type Response = HttpResponse;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped compression service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.compression.poll_ready(cx)
    }

    /// Sends `req` through the compression service.
    ///
    /// The resulting compressed body is converted into jsonrpsee's [`HttpBody`], so the
    /// response type matches [`HttpResponse`]. Errors from the inner service are passed through
    /// unchanged.
    fn call(&mut self, req: HttpRequest) -> Self::Future {
        let pending = self.compression.call(req);

        Box::pin(async move {
            let (parts, body) = pending.await?.into_parts();
            Ok(Self::Response::from_parts(parts, HttpBody::new(body)))
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING};
    use http_body_util::BodyExt;
    use jsonrpsee_http_client::{HttpRequest, HttpResponse};
    use std::{convert::Infallible, future::ready};

    const TEST_DATA: &str = "compress test data ";
    const REPEAT_COUNT: usize = 1000;

    /// Inner service that always answers with `TEST_DATA` repeated `REPEAT_COUNT` times.
    #[derive(Clone)]
    struct MockRequestService;

    impl Service<HttpRequest> for MockRequestService {
        type Response = HttpResponse;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(
            &mut self,
            _: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            std::task::Poll::Ready(Ok(()))
        }

        fn call(&mut self, _: HttpRequest) -> Self::Future {
            let body = HttpBody::from(TEST_DATA.repeat(REPEAT_COUNT));
            let response = HttpResponse::builder().body(body).unwrap();
            ready(Ok(response))
        }
    }

    /// Wraps [`MockRequestService`] in the layer under test.
    fn setup_compression_service(
    ) -> impl Service<HttpRequest, Response = HttpResponse, Error = Infallible> {
        CompressionLayer::new().layer(MockRequestService)
    }

    /// Byte length of the uncompressed payload produced by [`MockRequestService`].
    fn uncompressed_len() -> usize {
        TEST_DATA.len() * REPEAT_COUNT
    }

    /// Collects the whole response body and returns its length in bytes.
    async fn get_response_size(response: HttpResponse) -> usize {
        response.into_body().collect().await.unwrap().to_bytes().len()
    }

    /// Requests the given `encoding` and checks two things: the response advertises that
    /// encoding in `Content-Encoding`, and its body is smaller than the uncompressed payload.
    async fn assert_compressed_with(encoding: &str) {
        let mut service = setup_compression_service();
        let request = HttpRequest::builder()
            .header(ACCEPT_ENCODING, encoding)
            .body(HttpBody::empty())
            .unwrap();

        let response = service.call(request).await.unwrap();

        assert_eq!(
            response.headers().get(CONTENT_ENCODING).unwrap(),
            encoding,
            "Response should be {encoding} encoded"
        );

        let original_size = uncompressed_len();
        let compressed_size = get_response_size(response).await;
        assert!(
            compressed_size < original_size,
            "{encoding}: compressed size ({compressed_size}) should be smaller than original size ({original_size})"
        );
    }

    #[tokio::test]
    async fn test_gzip_compression() {
        assert_compressed_with("gzip").await;
    }

    // The tests below cover the remaining encodings enabled by `CompressionLayer::new`.

    #[tokio::test]
    async fn test_brotli_compression() {
        assert_compressed_with("br").await;
    }

    #[tokio::test]
    async fn test_deflate_compression() {
        assert_compressed_with("deflate").await;
    }

    #[tokio::test]
    async fn test_zstd_compression() {
        assert_compressed_with("zstd").await;
    }

    #[tokio::test]
    async fn test_no_compression_when_not_requested() {
        let mut service = setup_compression_service();
        // No `Accept-Encoding` header: the body must pass through unmodified.
        let request = HttpRequest::builder().body(HttpBody::empty()).unwrap();

        let response = service.call(request).await.unwrap();
        assert!(
            response.headers().get(CONTENT_ENCODING).is_none(),
            "Response should not be compressed when not requested"
        );

        let original_size = uncompressed_len();
        let response_size = get_response_size(response).await;
        assert_eq!(
            response_size, original_size,
            "Response size ({response_size}) should equal original size ({original_size})"
        );
    }
}