1use crate::Metrics;
5use futures::Stream;
6use metrics::Counter;
7use reth_primitives_traits::InMemorySize;
8use std::{
9 pin::Pin,
10 sync::{
11 atomic::{AtomicUsize, Ordering},
12 Arc,
13 },
14 task::{ready, Context, Poll},
15};
16use tokio::sync::mpsc::{
17 self,
18 error::{SendError, TryRecvError, TrySendError},
19};
20use tokio_util::sync::{PollSendError, PollSender};
21
22pub fn metered_unbounded_channel<T>(
24 scope: &'static str,
25) -> (UnboundedMeteredSender<T>, UnboundedMeteredReceiver<T>) {
26 let (tx, rx) = mpsc::unbounded_channel();
27 (UnboundedMeteredSender::new(tx, scope), UnboundedMeteredReceiver::new(rx, scope))
28}
29
30pub fn metered_channel<T>(
33 buffer: usize,
34 scope: &'static str,
35) -> (MeteredSender<T>, MeteredReceiver<T>) {
36 let (tx, rx) = mpsc::channel(buffer);
37 (MeteredSender::new(tx, scope), MeteredReceiver::new(rx, scope))
38}
39
/// Sending half of an unbounded tokio channel that counts successful and failed sends.
#[derive(Debug)]
pub struct UnboundedMeteredSender<T> {
    /// The wrapped tokio unbounded sender.
    sender: mpsc::UnboundedSender<T>,
    /// Counters updated on every send attempt.
    metrics: MeteredSenderMetrics,
}
48
49impl<T> UnboundedMeteredSender<T> {
50 pub fn new(sender: mpsc::UnboundedSender<T>, scope: &'static str) -> Self {
53 Self { sender, metrics: MeteredSenderMetrics::new(scope) }
54 }
55
56 pub fn send(&self, message: T) -> Result<(), SendError<T>> {
59 match self.sender.send(message) {
60 Ok(()) => {
61 self.metrics.messages_sent_total.increment(1);
62 Ok(())
63 }
64 Err(error) => {
65 self.metrics.send_errors_total.increment(1);
66 Err(error)
67 }
68 }
69 }
70}
71
72impl<T> Clone for UnboundedMeteredSender<T> {
73 fn clone(&self) -> Self {
74 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
75 }
76}
77
/// Receiving half of an unbounded tokio channel that counts every message it yields.
#[derive(Debug)]
pub struct UnboundedMeteredReceiver<T> {
    /// The wrapped tokio unbounded receiver.
    receiver: mpsc::UnboundedReceiver<T>,
    /// Counter bumped once per received message.
    metrics: MeteredReceiverMetrics,
}
87
88impl<T> UnboundedMeteredReceiver<T> {
91 pub fn new(receiver: mpsc::UnboundedReceiver<T>, scope: &'static str) -> Self {
94 Self { receiver, metrics: MeteredReceiverMetrics::new(scope) }
95 }
96
97 pub async fn recv(&mut self) -> Option<T> {
99 let msg = self.receiver.recv().await;
100 if msg.is_some() {
101 self.metrics.messages_received_total.increment(1);
102 }
103 msg
104 }
105
106 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
108 let msg = self.receiver.try_recv()?;
109 self.metrics.messages_received_total.increment(1);
110 Ok(msg)
111 }
112
113 pub fn close(&mut self) {
115 self.receiver.close();
116 }
117
118 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
120 let msg = ready!(self.receiver.poll_recv(cx));
121 if msg.is_some() {
122 self.metrics.messages_received_total.increment(1);
123 }
124 Poll::Ready(msg)
125 }
126}
127
128impl<T> Stream for UnboundedMeteredReceiver<T> {
129 type Item = T;
130
131 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
132 self.poll_recv(cx)
133 }
134}
135
/// Sending half of a bounded tokio channel that counts successful and failed sends.
#[derive(Debug)]
pub struct MeteredSender<T> {
    /// The wrapped tokio bounded sender.
    sender: mpsc::Sender<T>,
    /// Counters updated on send attempts and on sends through permits.
    metrics: MeteredSenderMetrics,
}
144
145impl<T> MeteredSender<T> {
146 pub fn new(sender: mpsc::Sender<T>, scope: &'static str) -> Self {
148 Self { sender, metrics: MeteredSenderMetrics::new(scope) }
149 }
150
151 pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> {
155 let Self { sender, metrics } = self;
156 sender.try_reserve_owned().map(|permit| OwnedPermit::new(permit, metrics.clone())).map_err(
157 |err| match err {
158 TrySendError::Full(sender) => TrySendError::Full(Self { sender, metrics }),
159 TrySendError::Closed(sender) => TrySendError::Closed(Self { sender, metrics }),
160 },
161 )
162 }
163
164 pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
168 self.sender.reserve_owned().await.map(|permit| OwnedPermit::new(permit, self.metrics))
169 }
170
171 pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
175 self.sender.reserve().await.map(|permit| Permit::new(permit, &self.metrics))
176 }
177
178 pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
182 self.sender.try_reserve().map(|permit| Permit::new(permit, &self.metrics))
183 }
184
185 pub const fn inner(&self) -> &mpsc::Sender<T> {
187 &self.sender
188 }
189
190 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
193 match self.sender.try_send(message) {
194 Ok(()) => {
195 self.metrics.messages_sent_total.increment(1);
196 Ok(())
197 }
198 Err(error) => {
199 self.metrics.send_errors_total.increment(1);
200 Err(error)
201 }
202 }
203 }
204
205 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
208 match self.sender.send(value).await {
209 Ok(()) => {
210 self.metrics.messages_sent_total.increment(1);
211 Ok(())
212 }
213 Err(error) => {
214 self.metrics.send_errors_total.increment(1);
215 Err(error)
216 }
217 }
218 }
219}
220
221impl<T> Clone for MeteredSender<T> {
222 fn clone(&self) -> Self {
223 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
224 }
225}
226
/// Owned reservation of one slot in a [`MeteredSender`]'s channel.
///
/// Sending through it is counted in the carried sender metrics.
#[derive(Debug)]
pub struct OwnedPermit<T> {
    /// The wrapped tokio owned permit.
    permit: mpsc::OwnedPermit<T>,
    /// Metrics taken over from the sender that created this permit.
    metrics: MeteredSenderMetrics,
}
235
236impl<T> OwnedPermit<T> {
237 pub const fn new(permit: mpsc::OwnedPermit<T>, metrics: MeteredSenderMetrics) -> Self {
240 Self { permit, metrics }
241 }
242
243 pub fn send(self, value: T) -> MeteredSender<T> {
245 let Self { permit, metrics } = self;
246 metrics.messages_sent_total.increment(1);
247 MeteredSender { sender: permit.send(value), metrics }
248 }
249}
250
/// Borrowed reservation of one slot in a [`MeteredSender`]'s channel.
#[derive(Debug)]
pub struct Permit<'a, T> {
    /// The wrapped tokio permit.
    permit: mpsc::Permit<'a, T>,
    /// Metrics of the sender that produced this permit.
    metrics_ref: &'a MeteredSenderMetrics,
}
258
259impl<'a, T> Permit<'a, T> {
260 pub const fn new(permit: mpsc::Permit<'a, T>, metrics_ref: &'a MeteredSenderMetrics) -> Self {
262 Self { permit, metrics_ref }
263 }
264
265 pub fn send(self, value: T) {
267 self.metrics_ref.messages_sent_total.increment(1);
268 self.permit.send(value);
269 }
270}
271
/// Receiving half of a bounded tokio channel that counts every message it yields.
#[derive(Debug)]
pub struct MeteredReceiver<T> {
    /// The wrapped tokio bounded receiver.
    receiver: mpsc::Receiver<T>,
    /// Counter bumped once per received message.
    metrics: MeteredReceiverMetrics,
}
280
281impl<T> MeteredReceiver<T> {
284 pub fn new(receiver: mpsc::Receiver<T>, scope: &'static str) -> Self {
286 Self { receiver, metrics: MeteredReceiverMetrics::new(scope) }
287 }
288
289 pub async fn recv(&mut self) -> Option<T> {
291 let msg = self.receiver.recv().await;
292 if msg.is_some() {
293 self.metrics.messages_received_total.increment(1);
294 }
295 msg
296 }
297
298 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
300 let msg = self.receiver.try_recv()?;
301 self.metrics.messages_received_total.increment(1);
302 Ok(msg)
303 }
304
305 pub fn close(&mut self) {
307 self.receiver.close();
308 }
309
310 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
312 let msg = ready!(self.receiver.poll_recv(cx));
313 if msg.is_some() {
314 self.metrics.messages_received_total.increment(1);
315 }
316 Poll::Ready(msg)
317 }
318}
319
320impl<T> Stream for MeteredReceiver<T> {
321 type Item = T;
322
323 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
324 self.poll_recv(cx)
325 }
326}
327
/// Counters maintained by the metered sending halves and the permits they hand out.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
pub struct MeteredSenderMetrics {
    /// Number of messages sent
    messages_sent_total: Counter,
    /// Number of failed message deliveries
    send_errors_total: Counter,
}
337
/// Counters maintained by the metered receiving halves.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredReceiverMetrics {
    /// Number of messages received
    messages_received_total: Counter,
}
345
/// A `tokio_util` [`PollSender`] wrapper that counts sent messages and pending reservations.
#[derive(Debug)]
pub struct MeteredPollSender<T> {
    /// The wrapped poll-based sender.
    sender: PollSender<T>,
    /// Counters for sends and back pressure.
    metrics: MeteredPollSenderMetrics,
}
354
355impl<T: Send + 'static> MeteredPollSender<T> {
356 pub fn new(sender: PollSender<T>, scope: &'static str) -> Self {
358 Self { sender, metrics: MeteredPollSenderMetrics::new(scope) }
359 }
360
361 pub const fn inner(&self) -> &PollSender<T> {
363 &self.sender
364 }
365
366 pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
369 match self.sender.poll_reserve(cx) {
370 Poll::Ready(Ok(permit)) => Poll::Ready(Ok(permit)),
371 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
372 Poll::Pending => {
373 self.metrics.back_pressure_total.increment(1);
374 Poll::Pending
375 }
376 }
377 }
378
379 pub fn send_item(&mut self, item: T) -> Result<(), PollSendError<T>> {
382 match self.sender.send_item(item) {
383 Ok(()) => {
384 self.metrics.messages_sent_total.increment(1);
385 Ok(())
386 }
387 Err(error) => Err(error),
388 }
389 }
390}
391
392impl<T> Clone for MeteredPollSender<T> {
393 fn clone(&self) -> Self {
394 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
395 }
396}
397
/// Counters maintained by [`MeteredPollSender`].
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredPollSenderMetrics {
    /// Number of messages sent
    messages_sent_total: Counter,
    /// Number of delayed message deliveries caused by a full channel
    back_pressure_total: Counter,
}
407
/// Byte budget shared (through an `Arc`) by a memory-bounded sender and its guards.
#[derive(Debug)]
struct MemoryBudget {
    /// Bytes currently reserved by accepted messages whose guards have not been dropped yet.
    used: AtomicUsize,
    /// Limit on `used`: a send that would push `used` above it is rejected.
    max_bytes: usize,
}
420
/// Reservation of `size` bytes in a [`MemoryBudget`]; the bytes are returned when dropped.
#[derive(Debug)]
struct BudgetGuard {
    /// Number of bytes this guard holds.
    size: usize,
    /// Budget the bytes are released back to on drop.
    budget: Arc<MemoryBudget>,
}
430
431impl Drop for BudgetGuard {
432 fn drop(&mut self) {
433 self.budget.used.fetch_sub(self.size, Ordering::Relaxed);
434 }
435}
436
/// A queued message paired with the guard that keeps its bytes reserved.
#[derive(Debug)]
struct Budgeted<T> {
    /// The user message.
    msg: T,
    /// Never read: held only so that dropping it releases the message's reservation.
    _guard: BudgetGuard,
}
448
449#[derive(Debug, Clone)]
460pub struct MemoryBoundedSender<T: InMemorySize> {
461 inner: UnboundedMeteredSender<Budgeted<T>>,
463 budget: Arc<MemoryBudget>,
465}
466
467impl<T: InMemorySize> MemoryBoundedSender<T> {
468 pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
472 let size = msg.size();
473
474 let prev = self.budget.used.fetch_add(size, Ordering::Relaxed);
476 if prev.saturating_add(size) > self.budget.max_bytes {
477 self.budget.used.fetch_sub(size, Ordering::Relaxed);
479 return Err(TrySendError::Full(msg));
480 }
481
482 let guard = BudgetGuard { size, budget: Arc::clone(&self.budget) };
483 let budgeted = Budgeted { msg, _guard: guard };
484
485 self.inner.send(budgeted).map_err(|e| {
486 TrySendError::Closed(e.0.msg)
488 })
489 }
490}
491
/// Receiving half of a memory-bounded channel.
///
/// Messages are unwrapped from their budget guard as they are received.
#[derive(Debug)]
pub struct MemoryBoundedReceiver<T> {
    /// Unbounded metered channel carrying each message together with its budget guard.
    inner: UnboundedMeteredReceiver<Budgeted<T>>,
}
501
502impl<T> MemoryBoundedReceiver<T> {
503 pub async fn recv(&mut self) -> Option<T> {
507 self.inner.recv().await.map(unwrap_budgeted)
508 }
509
510 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
514 self.inner.poll_recv(cx).map(|opt| opt.map(unwrap_budgeted))
515 }
516}
517
518fn unwrap_budgeted<T>(b: Budgeted<T>) -> T {
520 let Budgeted { msg, _guard } = b;
523 msg
524}
525
526impl<T> Stream for MemoryBoundedReceiver<T> {
527 type Item = T;
528
529 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
530 self.poll_recv(cx)
531 }
532}
533
534pub fn memory_bounded_channel<T: InMemorySize>(
540 max_bytes: usize,
541 scope: &'static str,
542) -> (MemoryBoundedSender<T>, MemoryBoundedReceiver<T>) {
543 let (tx, rx) = metered_unbounded_channel(scope);
544 let budget = Arc::new(MemoryBudget { used: AtomicUsize::new(0), max_bytes });
545
546 let sender = MemoryBoundedSender { inner: tx, budget };
547 let receiver = MemoryBoundedReceiver { inner: rx };
548
549 (sender, receiver)
550}