reth_eth_wire/
disconnect.rs

//! Peer disconnect support: the [`CanDisconnect`] trait and its basic stream implementations.
2
3use std::{future::Future, pin::Pin};
4
5use futures::{Sink, SinkExt};
6use reth_ecies::stream::ECIESStream;
7use reth_eth_wire_types::DisconnectReason;
8use tokio::io::AsyncWrite;
9use tokio_util::codec::{Encoder, Framed};
10
/// Outcome of a disconnect attempt: `Ok(())` on success, or the sink error `E` on failure.
type DisconnectResult<E> = std::result::Result<(), E>;
12
13/// This trait is meant to allow higher level protocols like `eth` to disconnect from a peer, using
14/// lower-level disconnect functions (such as those that exist in the `p2p` protocol) if the
15/// underlying stream supports it.
16pub trait CanDisconnect<T>: Sink<T> + Unpin {
17    /// Disconnects from the underlying stream, using a [`DisconnectReason`] as disconnect
18    /// information if the stream implements a protocol that can carry the additional disconnect
19    /// metadata.
20    fn disconnect(
21        &mut self,
22        reason: DisconnectReason,
23    ) -> Pin<Box<dyn Future<Output = DisconnectResult<Self::Error>> + Send + '_>>;
24}
25
26// basic impls for things like Framed<TcpStream, etc>
27impl<T, I, U> CanDisconnect<I> for Framed<T, U>
28where
29    T: AsyncWrite + Unpin + Send,
30    U: Encoder<I> + Send,
31{
32    fn disconnect(
33        &mut self,
34        _reason: DisconnectReason,
35    ) -> Pin<Box<dyn Future<Output = Result<(), <Self as Sink<I>>::Error>> + Send + '_>> {
36        Box::pin(async move { self.close().await })
37    }
38}
39
40impl<S> CanDisconnect<bytes::Bytes> for ECIESStream<S>
41where
42    S: AsyncWrite + Unpin + Send,
43{
44    fn disconnect(
45        &mut self,
46        _reason: DisconnectReason,
47    ) -> Pin<Box<dyn Future<Output = Result<(), std::io::Error>> + Send + '_>> {
48        Box::pin(async move { self.close().await })
49    }
50}
51
#[cfg(test)]
mod tests {
    use crate::{p2pstream::P2PMessage, DisconnectReason};
    use alloy_primitives::hex;
    use alloy_rlp::{Decodable, Encodable};

    /// The disconnect reasons exercised by the round-trip and encoding-length tests.
    fn all_reasons() -> Vec<DisconnectReason> {
        vec![
            DisconnectReason::DisconnectRequested,
            DisconnectReason::TcpSubsystemError,
            DisconnectReason::ProtocolBreach,
            DisconnectReason::UselessPeer,
            DisconnectReason::TooManyPeers,
            DisconnectReason::AlreadyConnected,
            DisconnectReason::IncompatibleP2PProtocolVersion,
            DisconnectReason::NullNodeIdentity,
            DisconnectReason::ClientQuitting,
            DisconnectReason::UnexpectedHandshakeIdentity,
            DisconnectReason::ConnectedToSelf,
            DisconnectReason::PingTimeout,
            DisconnectReason::SubprotocolSpecific,
        ]
    }

    #[test]
    fn disconnect_round_trip() {
        for reason in all_reasons() {
            let disconnect = P2PMessage::Disconnect(reason);

            let mut disconnect_encoded = Vec::new();
            disconnect.encode(&mut disconnect_encoded);

            let disconnect_decoded = P2PMessage::decode(&mut &disconnect_encoded[..]).unwrap();

            assert_eq!(disconnect, disconnect_decoded);
        }
    }

    #[test]
    fn test_reason_too_short() {
        assert!(DisconnectReason::decode(&mut &[0u8; 0][..]).is_err())
    }

    #[test]
    fn test_reason_too_long() {
        assert!(DisconnectReason::decode(&mut &[0u8; 3][..]).is_err())
    }

    #[test]
    fn test_reason_zero_length_list() {
        let list_with_zero_length = hex::decode("c000").unwrap();
        let res = DisconnectReason::decode(&mut &list_with_zero_length[..]);
        assert!(res.is_err());
        assert_eq!(res.unwrap_err().to_string(), "unexpected list length (got 0, expected 1)")
    }

    #[test]
    fn disconnect_encoding_length() {
        for reason in all_reasons() {
            let disconnect = P2PMessage::Disconnect(reason);

            let mut disconnect_encoded = Vec::new();
            disconnect.encode(&mut disconnect_encoded);

            assert_eq!(disconnect_encoded.len(), disconnect.length());
        }
    }

    /// Decodes known disconnect encodings and checks the exact reason, not just the message kind.
    ///
    /// Every input starts with `0x01`, the disconnect message id; the reason codes follow the
    /// devp2p RLPx disconnect-reason table.
    #[test]
    fn test_decode_known_reasons() {
        let cases = [
            // reason encoded as a single byte
            ("0100", DisconnectReason::DisconnectRequested), // 0x00 case
            ("0180", DisconnectReason::DisconnectRequested), // 0x80 (RLP empty string) case
            ("0101", DisconnectReason::TcpSubsystemError),
            ("0102", DisconnectReason::ProtocolBreach),
            ("0103", DisconnectReason::UselessPeer),
            ("0104", DisconnectReason::TooManyPeers),
            ("0105", DisconnectReason::AlreadyConnected),
            ("0106", DisconnectReason::IncompatibleP2PProtocolVersion),
            ("0107", DisconnectReason::NullNodeIdentity),
            ("0108", DisconnectReason::ClientQuitting),
            ("0109", DisconnectReason::UnexpectedHandshakeIdentity),
            ("010a", DisconnectReason::ConnectedToSelf),
            ("010b", DisconnectReason::PingTimeout),
            ("0110", DisconnectReason::SubprotocolSpecific),
            // reason wrapped in a single-element RLP list
            ("01c100", DisconnectReason::DisconnectRequested), // 0x00 case
            ("01c180", DisconnectReason::DisconnectRequested), // 0x80 (RLP empty string) case
            ("01c101", DisconnectReason::TcpSubsystemError),
            ("01c102", DisconnectReason::ProtocolBreach),
            ("01c103", DisconnectReason::UselessPeer),
            ("01c104", DisconnectReason::TooManyPeers),
            ("01c105", DisconnectReason::AlreadyConnected),
            ("01c106", DisconnectReason::IncompatibleP2PProtocolVersion),
            ("01c107", DisconnectReason::NullNodeIdentity),
            ("01c108", DisconnectReason::ClientQuitting),
            ("01c109", DisconnectReason::UnexpectedHandshakeIdentity),
            ("01c10a", DisconnectReason::ConnectedToSelf),
            ("01c10b", DisconnectReason::PingTimeout),
            ("01c110", DisconnectReason::SubprotocolSpecific),
        ];

        for (encoded, expected) in cases {
            let bytes = hex::decode(encoded).unwrap();
            let message = P2PMessage::decode(&mut &bytes[..]).unwrap();
            assert_eq!(message, P2PMessage::Disconnect(expected), "input: {encoded}");
        }
    }

    #[test]
    fn test_decode_disconnect_requested() {
        let reason = hex::decode("0100").unwrap();
        let message = P2PMessage::decode(&mut &reason[..]).unwrap();
        assert_eq!(message, P2PMessage::Disconnect(DisconnectReason::DisconnectRequested));
    }
}