1use crate::Metrics;
5use futures::Stream;
6use metrics::Counter;
7use std::{
8 pin::Pin,
9 task::{ready, Context, Poll},
10};
11use tokio::sync::mpsc::{
12 self,
13 error::{SendError, TryRecvError, TrySendError},
14};
15use tokio_util::sync::{PollSendError, PollSender};
16
17pub fn metered_unbounded_channel<T>(
19 scope: &'static str,
20) -> (UnboundedMeteredSender<T>, UnboundedMeteredReceiver<T>) {
21 let (tx, rx) = mpsc::unbounded_channel();
22 (UnboundedMeteredSender::new(tx, scope), UnboundedMeteredReceiver::new(rx, scope))
23}
24
25pub fn metered_channel<T>(
28 buffer: usize,
29 scope: &'static str,
30) -> (MeteredSender<T>, MeteredReceiver<T>) {
31 let (tx, rx) = mpsc::channel(buffer);
32 (MeteredSender::new(tx, scope), MeteredReceiver::new(rx, scope))
33}
34
35#[derive(Debug)]
37pub struct UnboundedMeteredSender<T> {
38 sender: mpsc::UnboundedSender<T>,
40 metrics: MeteredSenderMetrics,
42}
43
44impl<T> UnboundedMeteredSender<T> {
45 pub fn new(sender: mpsc::UnboundedSender<T>, scope: &'static str) -> Self {
48 Self { sender, metrics: MeteredSenderMetrics::new(scope) }
49 }
50
51 pub fn send(&self, message: T) -> Result<(), SendError<T>> {
55 match self.sender.send(message) {
56 Ok(()) => {
57 self.metrics.messages_sent_total.increment(1);
58 Ok(())
59 }
60 Err(error) => {
61 self.metrics.send_errors_total.increment(1);
62 Err(error)
63 }
64 }
65 }
66}
67
68impl<T> Clone for UnboundedMeteredSender<T> {
69 fn clone(&self) -> Self {
70 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
71 }
72}
73
74#[derive(Debug)]
76pub struct UnboundedMeteredReceiver<T> {
77 receiver: mpsc::UnboundedReceiver<T>,
79 metrics: MeteredReceiverMetrics,
81}
82
83impl<T> UnboundedMeteredReceiver<T> {
86 pub fn new(receiver: mpsc::UnboundedReceiver<T>, scope: &'static str) -> Self {
89 Self { receiver, metrics: MeteredReceiverMetrics::new(scope) }
90 }
91
92 pub async fn recv(&mut self) -> Option<T> {
94 let msg = self.receiver.recv().await;
95 if msg.is_some() {
96 self.metrics.messages_received_total.increment(1);
97 }
98 msg
99 }
100
101 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
103 let msg = self.receiver.try_recv()?;
104 self.metrics.messages_received_total.increment(1);
105 Ok(msg)
106 }
107
108 pub fn close(&mut self) {
110 self.receiver.close();
111 }
112
113 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
115 let msg = ready!(self.receiver.poll_recv(cx));
116 if msg.is_some() {
117 self.metrics.messages_received_total.increment(1);
118 }
119 Poll::Ready(msg)
120 }
121}
122
123impl<T> Stream for UnboundedMeteredReceiver<T> {
124 type Item = T;
125
126 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127 self.poll_recv(cx)
128 }
129}
130
131#[derive(Debug)]
133pub struct MeteredSender<T> {
134 sender: mpsc::Sender<T>,
136 metrics: MeteredSenderMetrics,
138}
139
140impl<T> MeteredSender<T> {
141 pub fn new(sender: mpsc::Sender<T>, scope: &'static str) -> Self {
143 Self { sender, metrics: MeteredSenderMetrics::new(scope) }
144 }
145
146 pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> {
150 let Self { sender, metrics } = self;
151 sender.try_reserve_owned().map(|permit| OwnedPermit::new(permit, metrics.clone())).map_err(
152 |err| match err {
153 TrySendError::Full(sender) => TrySendError::Full(Self { sender, metrics }),
154 TrySendError::Closed(sender) => TrySendError::Closed(Self { sender, metrics }),
155 },
156 )
157 }
158
159 pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
163 self.sender.reserve_owned().await.map(|permit| OwnedPermit::new(permit, self.metrics))
164 }
165
166 pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
170 self.sender.reserve().await.map(|permit| Permit::new(permit, &self.metrics))
171 }
172
173 pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
177 self.sender.try_reserve().map(|permit| Permit::new(permit, &self.metrics))
178 }
179
180 pub const fn inner(&self) -> &mpsc::Sender<T> {
182 &self.sender
183 }
184
185 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
188 match self.sender.try_send(message) {
189 Ok(()) => {
190 self.metrics.messages_sent_total.increment(1);
191 Ok(())
192 }
193 Err(error) => {
194 self.metrics.send_errors_total.increment(1);
195 Err(error)
196 }
197 }
198 }
199
200 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
203 match self.sender.send(value).await {
204 Ok(()) => {
205 self.metrics.messages_sent_total.increment(1);
206 Ok(())
207 }
208 Err(error) => {
209 self.metrics.send_errors_total.increment(1);
210 Err(error)
211 }
212 }
213 }
214}
215
216impl<T> Clone for MeteredSender<T> {
217 fn clone(&self) -> Self {
218 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
219 }
220}
221
222#[derive(Debug)]
225pub struct OwnedPermit<T> {
226 permit: mpsc::OwnedPermit<T>,
227 metrics: MeteredSenderMetrics,
229}
230
231impl<T> OwnedPermit<T> {
232 pub const fn new(permit: mpsc::OwnedPermit<T>, metrics: MeteredSenderMetrics) -> Self {
235 Self { permit, metrics }
236 }
237
238 pub fn send(self, value: T) -> MeteredSender<T> {
240 let Self { permit, metrics } = self;
241 metrics.messages_sent_total.increment(1);
242 MeteredSender { sender: permit.send(value), metrics }
243 }
244}
245
246#[derive(Debug)]
249pub struct Permit<'a, T> {
250 permit: mpsc::Permit<'a, T>,
251 metrics_ref: &'a MeteredSenderMetrics,
252}
253
254impl<'a, T> Permit<'a, T> {
255 pub const fn new(permit: mpsc::Permit<'a, T>, metrics_ref: &'a MeteredSenderMetrics) -> Self {
257 Self { permit, metrics_ref }
258 }
259
260 pub fn send(self, value: T) {
262 self.metrics_ref.messages_sent_total.increment(1);
263 self.permit.send(value);
264 }
265}
266
267#[derive(Debug)]
269pub struct MeteredReceiver<T> {
270 receiver: mpsc::Receiver<T>,
272 metrics: MeteredReceiverMetrics,
274}
275
276impl<T> MeteredReceiver<T> {
279 pub fn new(receiver: mpsc::Receiver<T>, scope: &'static str) -> Self {
281 Self { receiver, metrics: MeteredReceiverMetrics::new(scope) }
282 }
283
284 pub async fn recv(&mut self) -> Option<T> {
286 let msg = self.receiver.recv().await;
287 if msg.is_some() {
288 self.metrics.messages_received_total.increment(1);
289 }
290 msg
291 }
292
293 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
295 let msg = self.receiver.try_recv()?;
296 self.metrics.messages_received_total.increment(1);
297 Ok(msg)
298 }
299
300 pub fn close(&mut self) {
302 self.receiver.close();
303 }
304
305 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
307 let msg = ready!(self.receiver.poll_recv(cx));
308 if msg.is_some() {
309 self.metrics.messages_received_total.increment(1);
310 }
311 Poll::Ready(msg)
312 }
313}
314
315impl<T> Stream for MeteredReceiver<T> {
316 type Item = T;
317
318 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
319 self.poll_recv(cx)
320 }
321}
322
/// Throughput counters shared by the metered senders and the permits reserved from them.
///
/// With `dynamic = true`, instances are created at runtime via
/// `MeteredSenderMetrics::new(scope)`.
// NOTE(review): the `Metrics` derive appears to take each field's doc comment as the metric
// description (and to reject undocumented fields) — confirm against the derive macro.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
pub struct MeteredSenderMetrics {
    /// Number of messages sent
    messages_sent_total: Counter,
    /// Number of failed message deliveries
    send_errors_total: Counter,
}
332
/// Throughput counters for the metered receivers.
///
/// With `dynamic = true`, instances are created at runtime via
/// `MeteredReceiverMetrics::new(scope)`.
// NOTE(review): the `Metrics` derive appears to take each field's doc comment as the metric
// description (and to reject undocumented fields) — confirm against the derive macro.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredReceiverMetrics {
    /// Number of messages received
    messages_received_total: Counter,
}
340
341#[derive(Debug)]
343pub struct MeteredPollSender<T> {
344 sender: PollSender<T>,
346 metrics: MeteredPollSenderMetrics,
348}
349
350impl<T: Send + 'static> MeteredPollSender<T> {
351 pub fn new(sender: PollSender<T>, scope: &'static str) -> Self {
353 Self { sender, metrics: MeteredPollSenderMetrics::new(scope) }
354 }
355
356 pub const fn inner(&self) -> &PollSender<T> {
358 &self.sender
359 }
360
361 pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
364 match self.sender.poll_reserve(cx) {
365 Poll::Ready(Ok(permit)) => Poll::Ready(Ok(permit)),
366 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
367 Poll::Pending => {
368 self.metrics.back_pressure_total.increment(1);
369 Poll::Pending
370 }
371 }
372 }
373
374 pub fn send_item(&mut self, item: T) -> Result<(), PollSendError<T>> {
377 match self.sender.send_item(item) {
378 Ok(()) => {
379 self.metrics.messages_sent_total.increment(1);
380 Ok(())
381 }
382 Err(error) => Err(error),
383 }
384 }
385}
386
387impl<T> Clone for MeteredPollSender<T> {
388 fn clone(&self) -> Self {
389 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
390 }
391}
392
/// Throughput and back-pressure counters for the metered poll sender.
///
/// With `dynamic = true`, instances are created at runtime via
/// `MeteredPollSenderMetrics::new(scope)`.
// NOTE(review): the `Metrics` derive appears to take each field's doc comment as the metric
// description (and to reject undocumented fields) — confirm against the derive macro.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredPollSenderMetrics {
    /// Number of messages sent
    messages_sent_total: Counter,
    /// Number of times reserving channel capacity returned `Poll::Pending`
    back_pressure_total: Counter,
}