1use std::{
11 collections::VecDeque,
12 fmt,
13 future::Future,
14 io,
15 pin::{pin, Pin},
16 sync::Arc,
17 task::{ready, Context, Poll},
18};
19
20use crate::{
21 capability::{SharedCapabilities, SharedCapability, UnsupportedCapabilityError},
22 errors::{EthStreamError, P2PStreamError},
23 handshake::EthRlpxHandshake,
24 p2pstream::DisconnectP2P,
25 CanDisconnect, Capability, DisconnectReason, EthStream, P2PStream, UnifiedStatus,
26 HANDSHAKE_TIMEOUT,
27};
28use bytes::{Bytes, BytesMut};
29use futures::{Sink, SinkExt, Stream, StreamExt, TryStream, TryStreamExt};
30use reth_eth_wire_types::NetworkPrimitives;
31use reth_ethereum_forks::ForkFilter;
32use tokio::sync::{mpsc, mpsc::UnboundedSender};
33use tokio_stream::wrappers::UnboundedReceiverStream;
34
35#[derive(Debug)]
38pub struct RlpxProtocolMultiplexer<St> {
39 inner: MultiplexInner<St>,
40}
41
42impl<St> RlpxProtocolMultiplexer<St> {
43 pub fn new(conn: P2PStream<St>) -> Self {
45 Self {
46 inner: MultiplexInner {
47 conn,
48 protocols: Default::default(),
49 out_buffer: Default::default(),
50 },
51 }
52 }
53
54 pub fn install_protocol<F, Proto>(
59 &mut self,
60 cap: &Capability,
61 f: F,
62 ) -> Result<(), UnsupportedCapabilityError>
63 where
64 F: FnOnce(ProtocolConnection) -> Proto,
65 Proto: Stream<Item = BytesMut> + Send + 'static,
66 {
67 self.inner.install_protocol(cap, f)
68 }
69
70 pub const fn shared_capabilities(&self) -> &SharedCapabilities {
72 self.inner.shared_capabilities()
73 }
74
75 pub fn into_satellite_stream<F, Primary>(
77 self,
78 cap: &Capability,
79 primary: F,
80 ) -> Result<RlpxSatelliteStream<St, Primary>, P2PStreamError>
81 where
82 F: FnOnce(ProtocolProxy) -> Primary,
83 {
84 let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
85 else {
86 return Err(P2PStreamError::CapabilityNotShared)
87 };
88
89 let (to_primary, from_wire) = mpsc::unbounded_channel();
90 let (to_wire, from_primary) = mpsc::unbounded_channel();
91 let proxy = ProtocolProxy {
92 shared_cap: shared_cap.clone(),
93 from_wire: UnboundedReceiverStream::new(from_wire),
94 to_wire,
95 };
96
97 let st = primary(proxy);
98 Ok(RlpxSatelliteStream {
99 inner: self.inner,
100 primary: PrimaryProtocol {
101 to_primary,
102 from_primary: UnboundedReceiverStream::new(from_primary),
103 st,
104 shared_cap,
105 },
106 next_outbound: 0,
107 })
108 }
109
110 pub async fn into_satellite_stream_with_handshake<F, Fut, Err, Primary>(
115 self,
116 cap: &Capability,
117 handshake: F,
118 ) -> Result<RlpxSatelliteStream<St, Primary>, Err>
119 where
120 F: FnOnce(ProtocolProxy) -> Fut,
121 Fut: Future<Output = Result<Primary, Err>>,
122 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
123 P2PStreamError: Into<Err>,
124 {
125 self.into_satellite_stream_with_tuple_handshake(cap, async move |proxy| {
126 let st = handshake(proxy).await?;
127 Ok((st, ()))
128 })
129 .await
130 .map(|(st, _)| st)
131 }
132
133 pub async fn into_satellite_stream_with_tuple_handshake<F, Fut, Err, Primary, Extra>(
143 mut self,
144 cap: &Capability,
145 handshake: F,
146 ) -> Result<(RlpxSatelliteStream<St, Primary>, Extra), Err>
147 where
148 F: FnOnce(ProtocolProxy) -> Fut,
149 Fut: Future<Output = Result<(Primary, Extra), Err>>,
150 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
151 P2PStreamError: Into<Err>,
152 {
153 let Ok(shared_cap) = self.shared_capabilities().ensure_matching_capability(cap).cloned()
154 else {
155 return Err(P2PStreamError::CapabilityNotShared.into())
156 };
157
158 let (to_primary, from_wire) = mpsc::unbounded_channel();
159 let (to_wire, mut from_primary) = mpsc::unbounded_channel();
160 let proxy = ProtocolProxy {
161 shared_cap: shared_cap.clone(),
162 from_wire: UnboundedReceiverStream::new(from_wire),
163 to_wire,
164 };
165
166 let f = handshake(proxy);
167 let mut f = pin!(f);
168
169 loop {
172 tokio::select! {
173 biased;
174 Some(Ok(msg)) = self.inner.conn.next() => {
175 let Some(offset) = msg.first().copied()
177 else {
178 return Err(P2PStreamError::EmptyProtocolMessage.into())
179 };
180 if let Some(cap) = self.shared_capabilities().find_by_relative_offset(offset).cloned() {
181 if cap == shared_cap {
182 let _ = to_primary.send(msg);
184 } else {
185 self.inner.delegate_message(&cap, msg);
187 }
188 } else {
189 return Err(P2PStreamError::UnknownReservedMessageId(offset).into())
190 }
191 }
192 Some(msg) = from_primary.recv() => {
193 self.inner.conn.send(msg).await.map_err(Into::into)?;
194 }
195 msg = ProtocolsPoller::new(&mut self.inner.protocols) => {
197 self.inner.conn.send(msg.map_err(Into::into)?).await.map_err(Into::into)?;
198 }
199 res = &mut f => {
200 let (st, extra) = res?;
201 return Ok((
202 RlpxSatelliteStream {
203 inner: self.inner,
204 primary: PrimaryProtocol {
205 to_primary,
206 from_primary: UnboundedReceiverStream::new(from_primary),
207 st,
208 shared_cap,
209 },
210 next_outbound: 0,
211 },
212 extra,
213 ))
214 }
215 }
216 }
217 }
218
219 pub async fn into_eth_satellite_stream<N: NetworkPrimitives>(
222 self,
223 status: UnifiedStatus,
224 fork_filter: ForkFilter,
225 handshake: Arc<dyn EthRlpxHandshake>,
226 eth_max_message_size: usize,
227 ) -> Result<(RlpxSatelliteStream<St, EthStream<ProtocolProxy, N>>, UnifiedStatus), EthStreamError>
228 where
229 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
230 {
231 let eth_cap = self.inner.conn.shared_capabilities().eth_version()?;
232 self.into_satellite_stream_with_tuple_handshake(
233 &Capability::eth(eth_cap),
234 async move |proxy| {
235 let handshake = handshake.clone();
236 let mut unauth = UnauthProxy { inner: proxy };
237 let their_status = handshake
238 .handshake(&mut unauth, status, fork_filter, HANDSHAKE_TIMEOUT)
239 .await?;
240 let eth_stream = EthStream::with_max_message_size(
241 eth_cap,
242 unauth.into_inner(),
243 eth_max_message_size,
244 );
245 Ok((eth_stream, their_status))
246 },
247 )
248 .await
249 }
250}
251
252#[derive(Debug)]
253struct MultiplexInner<St> {
254 conn: P2PStream<St>,
256 protocols: VecDeque<ProtocolStream>,
258 out_buffer: OutBuffer,
260}
261
262impl<St> MultiplexInner<St> {
263 const fn shared_capabilities(&self) -> &SharedCapabilities {
264 self.conn.shared_capabilities()
265 }
266
267 fn delegate_message(&self, cap: &SharedCapability, msg: BytesMut) -> bool {
269 for proto in &self.protocols {
270 if proto.shared_cap == *cap {
271 proto.send_raw(msg);
272 return true
273 }
274 }
275 false
276 }
277
278 fn install_protocol<F, Proto>(
279 &mut self,
280 cap: &Capability,
281 f: F,
282 ) -> Result<(), UnsupportedCapabilityError>
283 where
284 F: FnOnce(ProtocolConnection) -> Proto,
285 Proto: Stream<Item = BytesMut> + Send + 'static,
286 {
287 let shared_cap =
288 self.conn.shared_capabilities().ensure_matching_capability(cap).cloned()?;
289 let (to_satellite, rx) = mpsc::unbounded_channel();
290 let proto_conn = ProtocolConnection { from_wire: UnboundedReceiverStream::new(rx) };
291 let st = f(proto_conn);
292 let st = ProtocolStream { shared_cap, to_satellite, satellite_st: Box::pin(st) };
293 self.protocols.push_back(st);
294 Ok(())
295 }
296}
297
298#[derive(Debug)]
300struct PrimaryProtocol<Primary> {
301 to_primary: UnboundedSender<BytesMut>,
303 from_primary: UnboundedReceiverStream<Bytes>,
305 shared_cap: SharedCapability,
307 st: Primary,
309}
310
311#[derive(Debug)]
315pub struct ProtocolProxy {
316 shared_cap: SharedCapability,
317 from_wire: UnboundedReceiverStream<BytesMut>,
319 to_wire: UnboundedSender<Bytes>,
321}
322
323impl ProtocolProxy {
324 fn try_send(&self, msg: Bytes) -> Result<(), io::Error> {
326 if msg.is_empty() {
327 return Err(io::ErrorKind::InvalidInput.into())
329 }
330 self.to_wire.send(self.mask_msg_id(msg)?).map_err(|_| io::ErrorKind::BrokenPipe.into())
331 }
332
333 #[inline]
335 fn mask_msg_id(&self, msg: Bytes) -> Result<Bytes, io::Error> {
336 if msg.is_empty() {
337 return Err(io::ErrorKind::InvalidInput.into())
339 }
340
341 let offset = self.shared_cap.relative_message_id_offset();
342 if offset == 0 {
343 return Ok(msg);
344 }
345
346 let mut masked: BytesMut = msg.into();
347 masked[0] = masked[0].checked_add(offset).ok_or(io::ErrorKind::InvalidInput)?;
348 Ok(masked.freeze())
349 }
350
351 #[inline]
353 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
354 if msg.is_empty() {
355 return Err(io::ErrorKind::InvalidInput.into())
357 }
358 msg[0] = msg[0]
359 .checked_sub(self.shared_cap.relative_message_id_offset())
360 .ok_or(io::ErrorKind::InvalidInput)?;
361 Ok(msg)
362 }
363}
364
365impl Stream for ProtocolProxy {
366 type Item = Result<BytesMut, io::Error>;
367
368 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
369 let msg = ready!(self.from_wire.poll_next_unpin(cx));
370 Poll::Ready(msg.map(|msg| self.get_mut().unmask_id(msg)))
371 }
372}
373
374impl Sink<Bytes> for ProtocolProxy {
375 type Error = io::Error;
376
377 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
378 Poll::Ready(Ok(()))
379 }
380
381 fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
382 self.get_mut().try_send(item)
383 }
384
385 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
386 Poll::Ready(Ok(()))
387 }
388
389 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
390 Poll::Ready(Ok(()))
391 }
392}
393
394impl CanDisconnect<Bytes> for ProtocolProxy {
395 fn disconnect(
396 &mut self,
397 _reason: DisconnectReason,
398 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
399 Box::pin(async move { Ok(()) })
400 }
401}
402
403#[derive(Debug)]
406struct UnauthProxy {
407 inner: ProtocolProxy,
408}
409
410impl UnauthProxy {
411 fn into_inner(self) -> ProtocolProxy {
412 self.inner
413 }
414}
415
416impl Stream for UnauthProxy {
417 type Item = Result<BytesMut, P2PStreamError>;
418
419 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
420 self.inner.poll_next_unpin(cx).map(|opt| opt.map(|res| res.map_err(P2PStreamError::from)))
421 }
422}
423
424impl Sink<Bytes> for UnauthProxy {
425 type Error = P2PStreamError;
426
427 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428 self.inner.poll_ready_unpin(cx).map_err(P2PStreamError::from)
429 }
430
431 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
432 self.inner.start_send_unpin(item).map_err(P2PStreamError::from)
433 }
434
435 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
436 self.inner.poll_flush_unpin(cx).map_err(P2PStreamError::from)
437 }
438
439 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
440 self.inner.poll_close_unpin(cx).map_err(P2PStreamError::from)
441 }
442}
443
444impl CanDisconnect<Bytes> for UnauthProxy {
445 fn disconnect(
446 &mut self,
447 reason: DisconnectReason,
448 ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<Bytes>>::Error>> + Send + '_>> {
449 let fut = self.inner.disconnect(reason);
450 Box::pin(async move { fut.await.map_err(P2PStreamError::from) })
451 }
452}
453
454#[derive(Debug)]
458pub struct ProtocolConnection {
459 from_wire: UnboundedReceiverStream<BytesMut>,
460}
461
462impl Stream for ProtocolConnection {
463 type Item = BytesMut;
464
465 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
466 self.from_wire.poll_next_unpin(cx)
467 }
468}
469
470#[derive(Debug)]
473pub struct RlpxSatelliteStream<St, Primary> {
474 inner: MultiplexInner<St>,
475 primary: PrimaryProtocol<Primary>,
476 next_outbound: usize,
478}
479
480impl<St, Primary> RlpxSatelliteStream<St, Primary> {
481 pub fn install_protocol<F, Proto>(
486 &mut self,
487 cap: &Capability,
488 f: F,
489 ) -> Result<(), UnsupportedCapabilityError>
490 where
491 F: FnOnce(ProtocolConnection) -> Proto,
492 Proto: Stream<Item = BytesMut> + Send + 'static,
493 {
494 self.inner.install_protocol(cap, f)
495 }
496
497 #[inline]
499 pub const fn primary(&self) -> &Primary {
500 &self.primary.st
501 }
502
503 #[inline]
505 pub const fn primary_mut(&mut self) -> &mut Primary {
506 &mut self.primary.st
507 }
508
509 #[inline]
511 pub const fn inner(&self) -> &P2PStream<St> {
512 &self.inner.conn
513 }
514
515 #[inline]
517 pub const fn inner_mut(&mut self) -> &mut P2PStream<St> {
518 &mut self.inner.conn
519 }
520
521 #[inline]
523 pub fn into_inner(self) -> P2PStream<St> {
524 self.inner.conn
525 }
526
527 fn poll_outbound_producers(&mut self, cx: &mut Context<'_>) -> Result<ProducerPoll, io::Error> {
533 let producers = self.inner.protocols.len() + 1;
534 let mut pending = 0;
535
536 while pending < producers {
537 if self.inner.out_buffer.is_full() {
538 return Ok(ProducerPoll::Full)
539 }
540
541 if self.next_outbound >= producers {
542 self.next_outbound = 0;
543 }
544
545 let producer = self.next_outbound;
546 self.next_outbound = (self.next_outbound + 1) % producers;
547
548 let msg = if producer == 0 {
549 match self.primary.from_primary.poll_next_unpin(cx) {
550 Poll::Ready(Some(msg)) => msg,
551 Poll::Ready(None) => return Ok(ProducerPoll::Closed),
552 Poll::Pending => {
553 pending += 1;
554 continue
555 }
556 }
557 } else {
558 let proto = self
559 .inner
560 .protocols
561 .get_mut(producer - 1)
562 .expect("outbound producer index checked against protocol count");
563 match proto.poll_next_unpin(cx) {
564 Poll::Ready(Some(msg)) => msg?,
565 Poll::Ready(None) => return Ok(ProducerPoll::Closed),
566 Poll::Pending => {
567 pending += 1;
568 continue
569 }
570 }
571 };
572
573 pending = 0;
574 self.inner.out_buffer.push_back(msg);
575 }
576
577 Ok(ProducerPoll::Pending)
578 }
579}
580
581impl<St, Primary, PrimaryErr> Stream for RlpxSatelliteStream<St, Primary>
582where
583 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
584 Primary: TryStream<Error = PrimaryErr> + Unpin,
585 P2PStreamError: Into<PrimaryErr>,
586{
587 type Item = Result<Primary::Ok, Primary::Error>;
588
589 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
590 let this = self.get_mut();
591
592 loop {
593 if let Poll::Ready(Some(msg)) = this.primary.st.try_poll_next_unpin(cx) {
595 return Poll::Ready(Some(msg))
596 }
597
598 let mut conn_ready = true;
599 loop {
600 match this.inner.conn.poll_ready_unpin(cx) {
601 Poll::Ready(Ok(())) => {
602 if let Some(msg) = this.inner.out_buffer.pop_front() {
603 if let Err(err) = this.inner.conn.start_send_unpin(msg) {
604 return Poll::Ready(Some(Err(err.into())))
605 }
606 } else {
607 break
608 }
609 }
610 Poll::Ready(Err(err)) => {
611 if let Err(disconnect_err) =
612 this.inner.conn.start_disconnect(DisconnectReason::DisconnectRequested)
613 {
614 return Poll::Ready(Some(Err(disconnect_err.into())))
615 }
616 return Poll::Ready(Some(Err(err.into())))
617 }
618 Poll::Pending => {
619 conn_ready = false;
620 break
621 }
622 }
623 }
624
625 match this.poll_outbound_producers(cx) {
626 Ok(ProducerPoll::Pending | ProducerPoll::Full) => {}
627 Ok(ProducerPoll::Closed) => return Poll::Ready(None),
628 Err(err) => return Poll::Ready(Some(Err(P2PStreamError::Io(err).into()))),
629 }
630
631 let mut delegated = false;
632 loop {
633 match this.inner.conn.poll_next_unpin(cx) {
635 Poll::Ready(Some(Ok(msg))) => {
636 delegated = true;
637 let Some(offset) = msg.first().copied() else {
638 return Poll::Ready(Some(Err(
639 P2PStreamError::EmptyProtocolMessage.into()
640 )))
641 };
642 if let Some(cap) =
644 this.inner.conn.shared_capabilities().find_by_relative_offset(offset)
645 {
646 if cap == &this.primary.shared_cap {
647 let _ = this.primary.to_primary.send(msg);
649 } else {
650 for proto in &this.inner.protocols {
652 if proto.shared_cap == *cap {
653 proto.send_raw(msg);
654 break
655 }
656 }
657 }
658 } else {
659 return Poll::Ready(Some(Err(P2PStreamError::UnknownReservedMessageId(
660 offset,
661 )
662 .into())))
663 }
664 }
665 Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
666 Poll::Ready(None) => {
667 return Poll::Ready(None)
669 }
670 Poll::Pending => break,
671 }
672 }
673
674 if !conn_ready || (!delegated && this.inner.out_buffer.is_empty()) {
675 return Poll::Pending
676 }
677 }
678 }
679}
680
681impl<St, Primary, T> Sink<T> for RlpxSatelliteStream<St, Primary>
682where
683 St: Stream<Item = io::Result<BytesMut>> + Sink<Bytes, Error = io::Error> + Unpin,
684 Primary: Sink<T> + Unpin,
685 P2PStreamError: Into<<Primary as Sink<T>>::Error>,
686{
687 type Error = <Primary as Sink<T>>::Error;
688
689 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
690 let this = self.get_mut();
691 if let Err(err) = ready!(this.inner.conn.poll_ready_unpin(cx)) {
692 return Poll::Ready(Err(err.into()))
693 }
694 if let Err(err) = ready!(this.primary.st.poll_ready_unpin(cx)) {
695 return Poll::Ready(Err(err))
696 }
697 Poll::Ready(Ok(()))
698 }
699
700 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
701 self.get_mut().primary.st.start_send_unpin(item)
702 }
703
704 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
705 self.get_mut().inner.conn.poll_flush_unpin(cx).map_err(Into::into)
706 }
707
708 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
709 self.get_mut().inner.conn.poll_close_unpin(cx).map_err(Into::into)
710 }
711}
712
713struct ProtocolStream {
715 shared_cap: SharedCapability,
716 to_satellite: UnboundedSender<BytesMut>,
718 satellite_st: Pin<Box<dyn Stream<Item = BytesMut> + Send>>,
719}
720
721impl ProtocolStream {
722 #[inline]
724 fn mask_msg_id(&self, mut msg: BytesMut) -> Result<Bytes, io::Error> {
725 if msg.is_empty() {
726 return Err(io::ErrorKind::InvalidInput.into())
728 }
729 msg[0] = msg[0]
730 .checked_add(self.shared_cap.relative_message_id_offset())
731 .ok_or(io::ErrorKind::InvalidInput)?;
732 Ok(msg.freeze())
733 }
734
735 #[inline]
737 fn unmask_id(&self, mut msg: BytesMut) -> Result<BytesMut, io::Error> {
738 if msg.is_empty() {
739 return Err(io::ErrorKind::InvalidInput.into())
741 }
742 msg[0] = msg[0]
743 .checked_sub(self.shared_cap.relative_message_id_offset())
744 .ok_or(io::ErrorKind::InvalidInput)?;
745 Ok(msg)
746 }
747
748 fn send_raw(&self, msg: BytesMut) {
750 let _ = self.unmask_id(msg).map(|msg| self.to_satellite.send(msg));
751 }
752}
753
754impl Stream for ProtocolStream {
755 type Item = Result<Bytes, io::Error>;
756
757 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
758 let this = self.get_mut();
759 let msg = ready!(this.satellite_st.as_mut().poll_next(cx));
760 Poll::Ready(msg.map(|msg| this.mask_msg_id(msg)))
761 }
762}
763
764impl fmt::Debug for ProtocolStream {
765 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766 f.debug_struct("ProtocolStream").field("cap", &self.shared_cap).finish_non_exhaustive()
767 }
768}
769
770struct ProtocolsPoller<'a> {
772 protocols: &'a mut VecDeque<ProtocolStream>,
773}
774
775impl<'a> ProtocolsPoller<'a> {
776 const fn new(protocols: &'a mut VecDeque<ProtocolStream>) -> Self {
777 Self { protocols }
778 }
779}
780
781impl<'a> Future for ProtocolsPoller<'a> {
782 type Output = Result<Bytes, P2PStreamError>;
783
784 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
785 let protocols = self.protocols.len();
786 for _ in 0..protocols {
787 let mut proto = self.protocols.pop_front().expect("protocol count checked");
788 match proto.poll_next_unpin(cx) {
789 Poll::Ready(Some(Err(err))) => {
790 self.protocols.push_back(proto);
791 return Poll::Ready(Err(P2PStreamError::from(err)))
792 }
793 Poll::Ready(Some(Ok(msg))) => {
794 self.protocols.push_back(proto);
796 return Poll::Ready(Ok(msg));
797 }
798 _ => {
799 self.protocols.push_back(proto);
801 }
802 }
803 }
804
805 Poll::Pending
807 }
808}
809
/// Byte budget of the multiplexer's outbound buffer (32 MiB).
///
/// Once the buffered bytes reach this value, producers are not polled again until the
/// connection has drained the buffer below it.
const MAX_MUX_OUT_BUFFER_BYTES: usize = 32 << 20;
819
820#[derive(Debug)]
821struct OutBuffer {
822 messages: VecDeque<Bytes>,
823 bytes: usize,
824 max_bytes: usize,
825}
826
827impl Default for OutBuffer {
828 fn default() -> Self {
829 Self { messages: Default::default(), bytes: 0, max_bytes: MAX_MUX_OUT_BUFFER_BYTES }
830 }
831}
832
833impl OutBuffer {
834 fn push_back(&mut self, msg: Bytes) {
835 self.bytes += msg.len();
836 self.messages.push_back(msg);
837 }
838
839 fn pop_front(&mut self) -> Option<Bytes> {
840 let msg = self.messages.pop_front()?;
841 self.bytes -= msg.len();
842 Some(msg)
843 }
844
845 fn is_empty(&self) -> bool {
846 self.messages.is_empty()
847 }
848
849 const fn is_full(&self) -> bool {
850 self.bytes >= self.max_bytes
851 }
852}
853
/// Outcome of one `RlpxSatelliteStream::poll_outbound_producers` pass.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ProducerPoll {
    /// Every producer returned `Pending` in a row.
    Pending,
    /// The out buffer reached its byte limit before all producers were drained.
    Full,
    /// A producer stream ended.
    Closed,
}
864
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        handshake::EthHandshake,
        message::MAX_MESSAGE_SIZE,
        protocol::Protocol,
        test_utils::{
            connect_passthrough, eth_handshake, eth_hello,
            proto::{test_hello, TestProtoMessage},
        },
        UnauthedEthStream, UnauthedP2PStream,
    };
    use futures::{stream, task::noop_waker_ref};
    use reth_eth_wire_types::EthNetworkPrimitives;
    use std::task::Poll;
    use tokio::{net::TcpListener, sync::oneshot};
    use tokio_util::codec::Decoder;

    /// Primary protocol stub that never yields an item.
    ///
    /// It only keeps the proxy alive, which keeps the primary's outbound channel open.
    #[derive(Debug)]
    struct PendingPrimary {
        _proxy: ProtocolProxy,
    }

    impl Stream for PendingPrimary {
        type Item = Result<(), P2PStreamError>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    /// Transport that never yields inbound data and never becomes ready, flushed or closed.
    #[derive(Debug)]
    struct StalledTransport;

    impl Stream for StalledTransport {
        type Item = io::Result<BytesMut>;

        fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Pending
        }
    }

    impl Sink<Bytes> for StalledTransport {
        type Error = io::Error;

        fn poll_ready(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn start_send(self: Pin<&mut Self>, _item: Bytes) -> Result<(), Self::Error> {
            Ok(())
        }

        fn poll_flush(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll_close(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }
    }

    /// A satellite that could produce 4096 messages must stop being polled once the out buffer
    /// is full.
    ///
    /// The buffer limit is just over four messages, and the stalled transport never drains it.
    /// The buffer may overshoot the limit by at most one message.
    #[tokio::test]
    async fn satellite_mux_stops_polling_protocols_when_out_buffer_is_full() {
        let (hello, _) = test_hello();
        let shared_capabilities =
            SharedCapabilities::try_new(hello.protocols.clone(), hello.message().capabilities)
                .unwrap();
        let conn = P2PStream::new(StalledTransport, shared_capabilities);
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let mut st = RlpxProtocolMultiplexer::new(conn)
            .into_satellite_stream(eth.capability().as_ref(), |proxy| PendingPrimary {
                _proxy: proxy,
            })
            .unwrap();
        const MESSAGE_COUNT: usize = 4096;
        const MESSAGE_BYTES: usize = 1024;
        // One byte over four messages, so the fifth push is what makes the buffer full.
        st.inner.out_buffer.max_bytes = 4 * MESSAGE_BYTES + 1;
        st.install_protocol(&TestProtoMessage::capability(), |_conn| {
            stream::iter((0..MESSAGE_COUNT).map(|_| {
                let mut msg = BytesMut::zeroed(MESSAGE_BYTES);
                msg[0] = TestProtoMessage::ping().message_type as u8;
                msg
            }))
        })
        .unwrap();

        let mut cx = Context::from_waker(noop_waker_ref());
        assert!(Pin::new(&mut st).poll_next(&mut cx).is_pending());

        assert!(st.inner.out_buffer.bytes > st.inner.out_buffer.max_bytes);
        assert!(st.inner.out_buffer.bytes <= st.inner.out_buffer.max_bytes + MESSAGE_BYTES);
        assert!(st.inner.out_buffer.messages.len() < MESSAGE_COUNT);
    }

    /// Two always-ready satellites share a tiny out buffer (5 bytes, 2-byte messages).
    ///
    /// The first two buffered messages must come from different satellites, which shows the
    /// producers are polled round-robin. The satellites are told apart by the masked id byte,
    /// which equals each capability's offset because the unmasked id is 0.
    #[tokio::test]
    async fn satellite_mux_round_robins_ready_protocols_when_out_buffer_fills() {
        let (mut hello, _) = eth_hello();
        let cap_a = Capability::new_static("aaa", 1);
        let cap_b = Capability::new_static("bbb", 1);
        hello.protocols.push(Protocol::new(cap_a.clone(), 1));
        hello.protocols.push(Protocol::new(cap_b.clone(), 1));

        let shared_capabilities =
            SharedCapabilities::try_new(hello.protocols.clone(), hello.message().capabilities)
                .unwrap();
        let conn = P2PStream::new(StalledTransport, shared_capabilities);
        let eth = conn.shared_capabilities().eth().unwrap().clone();
        let cap_a_offset =
            conn.shared_capabilities().find(&cap_a).unwrap().relative_message_id_offset();
        let cap_b_offset =
            conn.shared_capabilities().find(&cap_b).unwrap().relative_message_id_offset();

        let mut st = RlpxProtocolMultiplexer::new(conn)
            .into_satellite_stream(eth.capability().as_ref(), |proxy| PendingPrimary {
                _proxy: proxy,
            })
            .unwrap();
        st.inner.out_buffer.max_bytes = 5;
        st.install_protocol(&cap_a, |_conn| {
            stream::iter((0..16).map(|_| BytesMut::from(&[0, b'a'][..])))
        })
        .unwrap();
        st.install_protocol(&cap_b, |_conn| {
            stream::iter((0..16).map(|_| BytesMut::from(&[0, b'b'][..])))
        })
        .unwrap();

        let mut cx = Context::from_waker(noop_waker_ref());
        assert!(Pin::new(&mut st).poll_next(&mut cx).is_pending());

        let message_ids =
            st.inner.out_buffer.messages.iter().take(2).map(|msg| msg[0]).collect::<Vec<_>>();
        assert_eq!(message_ids.len(), 2);
        assert_ne!(message_ids[0], message_ids[1]);
        assert!(message_ids.contains(&cap_a_offset));
        assert!(message_ids.contains(&cap_b_offset));
    }

    /// The client runs the eth handshake through a satellite `ProtocolProxy` against a server
    /// that uses a plain `UnauthedEthStream`.
    #[tokio::test]
    async fn eth_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = eth_hello();
            let (p2p_stream, _) =
                UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (_eth_stream, _) = UnauthedEthStream::new(p2p_stream)
                .handshake::<EthNetworkPrimitives>(other_status, other_fork_filter)
                .await
                .unwrap();

            // Keep the server connection open briefly after its side of the handshake.
            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        });

        let conn = connect_passthrough(local_addr, eth_hello().0).await;
        let eth = conn.shared_capabilities().eth().unwrap().clone();

        let multiplexer = RlpxProtocolMultiplexer::new(conn);
        let _satellite = multiplexer
            .into_satellite_stream_with_handshake(eth.capability().as_ref(), async move |proxy| {
                UnauthedEthStream::new(proxy)
                    .handshake::<EthNetworkPrimitives>(status, fork_filter)
                    .await
            })
            .await
            .unwrap();
    }

    /// Both peers run eth as the primary protocol plus a test-protocol satellite.
    ///
    /// The satellites exchange ping/pong and two text messages. The client side signals
    /// completion through a oneshot once it has received the server's "good bye!".
    #[tokio::test(flavor = "multi_thread")]
    async fn eth_test_protocol_satellite() {
        reth_tracing::init_test_tracing();
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local_addr = listener.local_addr().unwrap();
        let (status, fork_filter) = eth_handshake();
        let other_status = status;
        let other_fork_filter = fork_filter.clone();
        let _handle = tokio::spawn(async move {
            let (incoming, _) = listener.accept().await.unwrap();
            let stream = crate::PassthroughCodec::default().framed(incoming);
            let (server_hello, _) = test_hello();
            let (conn, _) = UnauthedP2PStream::new(stream).handshake(server_hello).await.unwrap();

            let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
                .into_eth_satellite_stream::<EthNetworkPrimitives>(
                    other_status,
                    other_fork_filter,
                    Arc::new(EthHandshake::default()),
                    MAX_MESSAGE_SIZE,
                )
                .await
                .unwrap();

            // Server satellite: ping -> expect pong, "hello" -> expect "good bye!", then send
            // "good bye!" and idle.
            st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
                async_stream::stream! {
                    yield TestProtoMessage::ping().encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::pong());

                    yield TestProtoMessage::message("hello").encoded();
                    let msg = conn.next().await.unwrap();
                    let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                    assert_eq!(msg, TestProtoMessage::message("good bye!"));

                    yield TestProtoMessage::message("good bye!").encoded();

                    futures::future::pending::<()>().await;
                    unreachable!()
                }
            })
            .unwrap();

            // Polling the satellite stream drives all routing for the server side.
            loop {
                let _ = st.next().await;
            }
        });

        let conn = connect_passthrough(local_addr, test_hello().0).await;
        let (mut st, _their_status) = RlpxProtocolMultiplexer::new(conn)
            .into_eth_satellite_stream::<EthNetworkPrimitives>(
                status,
                fork_filter,
                Arc::new(EthHandshake::default()),
                MAX_MESSAGE_SIZE,
            )
            .await
            .unwrap();

        let (tx, mut rx) = oneshot::channel();

        // Client satellite: mirrors the server's script and fires `tx` at the end.
        st.install_protocol(&TestProtoMessage::capability(), |mut conn| {
            async_stream::stream! {
                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::ping());

                yield TestProtoMessage::pong().encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("hello"));

                yield TestProtoMessage::message("good bye!").encoded();

                let msg = conn.next().await.unwrap();
                let msg = TestProtoMessage::decode_message(&mut &msg[..]).unwrap();
                assert_eq!(msg, TestProtoMessage::message("good bye!"));

                tx.send(()).unwrap();

                futures::future::pending::<()>().await;
                unreachable!()
            }
        })
        .unwrap();

        // Drive the client stream until the client satellite reports completion.
        loop {
            tokio::select! {
                _ = &mut rx => {
                    break
                }
                _ = st.next() => {
                }
            }
        }
    }
}