1use crate::Metrics;
5use futures::Stream;
6use metrics::Counter;
7use std::{
8 pin::Pin,
9 task::{ready, Context, Poll},
10};
11use tokio::sync::mpsc::{
12 self,
13 error::{SendError, TryRecvError, TrySendError},
14};
15use tokio_util::sync::{PollSendError, PollSender};
16
17pub fn metered_unbounded_channel<T>(
19 scope: &'static str,
20) -> (UnboundedMeteredSender<T>, UnboundedMeteredReceiver<T>) {
21 let (tx, rx) = mpsc::unbounded_channel();
22 (UnboundedMeteredSender::new(tx, scope), UnboundedMeteredReceiver::new(rx, scope))
23}
24
25pub fn metered_channel<T>(
28 buffer: usize,
29 scope: &'static str,
30) -> (MeteredSender<T>, MeteredReceiver<T>) {
31 let (tx, rx) = mpsc::channel(buffer);
32 (MeteredSender::new(tx, scope), MeteredReceiver::new(rx, scope))
33}
34
35#[derive(Debug)]
37pub struct UnboundedMeteredSender<T> {
38 sender: mpsc::UnboundedSender<T>,
40 metrics: MeteredSenderMetrics,
42}
43
44impl<T> UnboundedMeteredSender<T> {
45 pub fn new(sender: mpsc::UnboundedSender<T>, scope: &'static str) -> Self {
48 Self { sender, metrics: MeteredSenderMetrics::new(scope) }
49 }
50
51 pub fn send(&self, message: T) -> Result<(), SendError<T>> {
54 match self.sender.send(message) {
55 Ok(()) => {
56 self.metrics.messages_sent_total.increment(1);
57 Ok(())
58 }
59 Err(error) => {
60 self.metrics.send_errors_total.increment(1);
61 Err(error)
62 }
63 }
64 }
65}
66
67impl<T> Clone for UnboundedMeteredSender<T> {
68 fn clone(&self) -> Self {
69 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
70 }
71}
72
73#[derive(Debug)]
75pub struct UnboundedMeteredReceiver<T> {
76 receiver: mpsc::UnboundedReceiver<T>,
78 metrics: MeteredReceiverMetrics,
80}
81
82impl<T> UnboundedMeteredReceiver<T> {
85 pub fn new(receiver: mpsc::UnboundedReceiver<T>, scope: &'static str) -> Self {
88 Self { receiver, metrics: MeteredReceiverMetrics::new(scope) }
89 }
90
91 pub async fn recv(&mut self) -> Option<T> {
93 let msg = self.receiver.recv().await;
94 if msg.is_some() {
95 self.metrics.messages_received_total.increment(1);
96 }
97 msg
98 }
99
100 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
102 let msg = self.receiver.try_recv()?;
103 self.metrics.messages_received_total.increment(1);
104 Ok(msg)
105 }
106
107 pub fn close(&mut self) {
109 self.receiver.close();
110 }
111
112 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
114 let msg = ready!(self.receiver.poll_recv(cx));
115 if msg.is_some() {
116 self.metrics.messages_received_total.increment(1);
117 }
118 Poll::Ready(msg)
119 }
120}
121
122impl<T> Stream for UnboundedMeteredReceiver<T> {
123 type Item = T;
124
125 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
126 self.poll_recv(cx)
127 }
128}
129
130#[derive(Debug)]
132pub struct MeteredSender<T> {
133 sender: mpsc::Sender<T>,
135 metrics: MeteredSenderMetrics,
137}
138
139impl<T> MeteredSender<T> {
140 pub fn new(sender: mpsc::Sender<T>, scope: &'static str) -> Self {
142 Self { sender, metrics: MeteredSenderMetrics::new(scope) }
143 }
144
145 pub fn try_reserve_owned(self) -> Result<OwnedPermit<T>, TrySendError<Self>> {
149 let Self { sender, metrics } = self;
150 sender.try_reserve_owned().map(|permit| OwnedPermit::new(permit, metrics.clone())).map_err(
151 |err| match err {
152 TrySendError::Full(sender) => TrySendError::Full(Self { sender, metrics }),
153 TrySendError::Closed(sender) => TrySendError::Closed(Self { sender, metrics }),
154 },
155 )
156 }
157
158 pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
162 self.sender.reserve_owned().await.map(|permit| OwnedPermit::new(permit, self.metrics))
163 }
164
165 pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
169 self.sender.reserve().await.map(|permit| Permit::new(permit, &self.metrics))
170 }
171
172 pub fn try_reserve(&self) -> Result<Permit<'_, T>, TrySendError<()>> {
176 self.sender.try_reserve().map(|permit| Permit::new(permit, &self.metrics))
177 }
178
179 pub const fn inner(&self) -> &mpsc::Sender<T> {
181 &self.sender
182 }
183
184 pub fn try_send(&self, message: T) -> Result<(), TrySendError<T>> {
187 match self.sender.try_send(message) {
188 Ok(()) => {
189 self.metrics.messages_sent_total.increment(1);
190 Ok(())
191 }
192 Err(error) => {
193 self.metrics.send_errors_total.increment(1);
194 Err(error)
195 }
196 }
197 }
198
199 pub async fn send(&self, value: T) -> Result<(), SendError<T>> {
202 match self.sender.send(value).await {
203 Ok(()) => {
204 self.metrics.messages_sent_total.increment(1);
205 Ok(())
206 }
207 Err(error) => {
208 self.metrics.send_errors_total.increment(1);
209 Err(error)
210 }
211 }
212 }
213}
214
215impl<T> Clone for MeteredSender<T> {
216 fn clone(&self) -> Self {
217 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
218 }
219}
220
221#[derive(Debug)]
224pub struct OwnedPermit<T> {
225 permit: mpsc::OwnedPermit<T>,
226 metrics: MeteredSenderMetrics,
228}
229
230impl<T> OwnedPermit<T> {
231 pub const fn new(permit: mpsc::OwnedPermit<T>, metrics: MeteredSenderMetrics) -> Self {
234 Self { permit, metrics }
235 }
236
237 pub fn send(self, value: T) -> MeteredSender<T> {
239 let Self { permit, metrics } = self;
240 metrics.messages_sent_total.increment(1);
241 MeteredSender { sender: permit.send(value), metrics }
242 }
243}
244
245#[derive(Debug)]
248pub struct Permit<'a, T> {
249 permit: mpsc::Permit<'a, T>,
250 metrics_ref: &'a MeteredSenderMetrics,
251}
252
253impl<'a, T> Permit<'a, T> {
254 pub const fn new(permit: mpsc::Permit<'a, T>, metrics_ref: &'a MeteredSenderMetrics) -> Self {
256 Self { permit, metrics_ref }
257 }
258
259 pub fn send(self, value: T) {
261 self.metrics_ref.messages_sent_total.increment(1);
262 self.permit.send(value);
263 }
264}
265
266#[derive(Debug)]
268pub struct MeteredReceiver<T> {
269 receiver: mpsc::Receiver<T>,
271 metrics: MeteredReceiverMetrics,
273}
274
275impl<T> MeteredReceiver<T> {
278 pub fn new(receiver: mpsc::Receiver<T>, scope: &'static str) -> Self {
280 Self { receiver, metrics: MeteredReceiverMetrics::new(scope) }
281 }
282
283 pub async fn recv(&mut self) -> Option<T> {
285 let msg = self.receiver.recv().await;
286 if msg.is_some() {
287 self.metrics.messages_received_total.increment(1);
288 }
289 msg
290 }
291
292 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
294 let msg = self.receiver.try_recv()?;
295 self.metrics.messages_received_total.increment(1);
296 Ok(msg)
297 }
298
299 pub fn close(&mut self) {
301 self.receiver.close();
302 }
303
304 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
306 let msg = ready!(self.receiver.poll_recv(cx));
307 if msg.is_some() {
308 self.metrics.messages_received_total.increment(1);
309 }
310 Poll::Ready(msg)
311 }
312}
313
314impl<T> Stream for MeteredReceiver<T> {
315 type Item = T;
316
317 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
318 self.poll_recv(cx)
319 }
320}
321
/// Send-side counters shared by [`UnboundedMeteredSender`], [`MeteredSender`], [`OwnedPermit`]
/// and [`Permit`].
///
/// Instances are created with `MeteredSenderMetrics::new(scope)`, so the scope label is chosen
/// at runtime by the caller.
// NOTE(review): the field doc comments may be consumed by the `Metrics` derive as metric
// descriptions — confirm before rewording them.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
pub struct MeteredSenderMetrics {
    /// Number of messages sent successfully, including messages sent through a permit.
    messages_sent_total: Counter,
    /// Number of failed `send`/`try_send` calls (failed reservations are not counted).
    send_errors_total: Counter,
}
331
/// Receive-side counters for [`UnboundedMeteredReceiver`] and [`MeteredReceiver`].
///
/// Instances are created with `MeteredReceiverMetrics::new(scope)`, so the scope label is chosen
/// at runtime by the caller.
// NOTE(review): the field doc comments may be consumed by the `Metrics` derive as metric
// descriptions — confirm before rewording them.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredReceiverMetrics {
    /// Number of messages handed to the caller (`None` results and errors are not counted).
    messages_received_total: Counter,
}
339
340#[derive(Debug)]
342pub struct MeteredPollSender<T> {
343 sender: PollSender<T>,
345 metrics: MeteredPollSenderMetrics,
347}
348
349impl<T: Send + 'static> MeteredPollSender<T> {
350 pub fn new(sender: PollSender<T>, scope: &'static str) -> Self {
352 Self { sender, metrics: MeteredPollSenderMetrics::new(scope) }
353 }
354
355 pub const fn inner(&self) -> &PollSender<T> {
357 &self.sender
358 }
359
360 pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
363 match self.sender.poll_reserve(cx) {
364 Poll::Ready(Ok(permit)) => Poll::Ready(Ok(permit)),
365 Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
366 Poll::Pending => {
367 self.metrics.back_pressure_total.increment(1);
368 Poll::Pending
369 }
370 }
371 }
372
373 pub fn send_item(&mut self, item: T) -> Result<(), PollSendError<T>> {
376 match self.sender.send_item(item) {
377 Ok(()) => {
378 self.metrics.messages_sent_total.increment(1);
379 Ok(())
380 }
381 Err(error) => Err(error),
382 }
383 }
384}
385
386impl<T> Clone for MeteredPollSender<T> {
387 fn clone(&self) -> Self {
388 Self { sender: self.sender.clone(), metrics: self.metrics.clone() }
389 }
390}
391
/// Counters for [`MeteredPollSender`].
///
/// Instances are created with `MeteredPollSenderMetrics::new(scope)`, so the scope label is
/// chosen at runtime by the caller.
// NOTE(review): the field doc comments may be consumed by the `Metrics` derive as metric
// descriptions — confirm before rewording them.
#[derive(Clone, Metrics)]
#[metrics(dynamic = true)]
struct MeteredPollSenderMetrics {
    /// Number of items successfully passed to `send_item`.
    messages_sent_total: Counter,
    /// Number of `poll_reserve` calls that returned `Poll::Pending`.
    back_pressure_total: Counter,
}