reth_discv4/
test_utils.rs

1//! Mock discovery support
2
3use crate::{
4    proto::{FindNode, Message, Neighbours, NodeEndpoint, Packet, Ping, Pong},
5    receive_loop, send_loop, Discv4, Discv4Config, Discv4Service, EgressSender, IngressEvent,
6    IngressReceiver, PeerId, SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS,
7};
8use alloy_primitives::{hex, B256};
9use rand::{thread_rng, Rng, RngCore};
10use reth_ethereum_forks::{ForkHash, ForkId};
11use reth_network_peers::{pk2id, NodeRecord};
12use secp256k1::{SecretKey, SECP256K1};
13use std::{
14    collections::{HashMap, HashSet},
15    io,
16    net::{IpAddr, SocketAddr},
17    pin::Pin,
18    str::FromStr,
19    sync::Arc,
20    task::{Context, Poll},
21    time::{Duration, SystemTime, UNIX_EPOCH},
22};
23use tokio::{
24    net::UdpSocket,
25    sync::mpsc,
26    task::{JoinHandle, JoinSet},
27};
28use tokio_stream::{Stream, StreamExt};
29use tracing::debug;
30
31/// Mock discovery node
32#[derive(Debug)]
33pub struct MockDiscovery {
34    local_addr: SocketAddr,
35    local_enr: NodeRecord,
36    secret_key: SecretKey,
37    _udp: Arc<UdpSocket>,
38    _tasks: JoinSet<()>,
39    /// Receiver for incoming messages
40    ingress: IngressReceiver,
41    /// Sender for sending outgoing messages
42    egress: EgressSender,
43    pending_pongs: HashSet<PeerId>,
44    pending_neighbours: HashMap<PeerId, Vec<NodeRecord>>,
45    command_rx: mpsc::Receiver<MockCommand>,
46}
47
48impl MockDiscovery {
49    /// Creates a new instance and opens a socket
50    pub async fn new() -> io::Result<(Self, mpsc::Sender<MockCommand>)> {
51        let mut rng = thread_rng();
52        let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
53        let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
54        let id = pk2id(&pk);
55        let socket = Arc::new(UdpSocket::bind(socket).await?);
56        let local_addr = socket.local_addr()?;
57        let local_enr = NodeRecord {
58            address: local_addr.ip(),
59            tcp_port: local_addr.port(),
60            udp_port: local_addr.port(),
61            id,
62        };
63
64        let (ingress_tx, ingress_rx) = mpsc::channel(128);
65        let (egress_tx, egress_rx) = mpsc::channel(128);
66        let mut tasks = JoinSet::<()>::new();
67
68        let udp = Arc::clone(&socket);
69        tasks.spawn(receive_loop(udp, ingress_tx, local_enr.id));
70
71        let udp = Arc::clone(&socket);
72        tasks.spawn(send_loop(udp, egress_rx));
73
74        let (tx, command_rx) = mpsc::channel(128);
75        let this = Self {
76            _tasks: tasks,
77            ingress: ingress_rx,
78            egress: egress_tx,
79            local_addr,
80            local_enr,
81            secret_key,
82            _udp: socket,
83            pending_pongs: Default::default(),
84            pending_neighbours: Default::default(),
85            command_rx,
86        };
87        Ok((this, tx))
88    }
89
90    /// Spawn and consume the stream.
91    pub fn spawn(self) -> JoinHandle<()> {
92        tokio::task::spawn(async move {
93            let _: Vec<_> = self.collect().await;
94        })
95    }
96
97    /// Queue a pending pong.
98    pub fn queue_pong(&mut self, from: PeerId) {
99        self.pending_pongs.insert(from);
100    }
101
102    /// Queue a pending Neighbours response.
103    pub fn queue_neighbours(&mut self, target: PeerId, nodes: Vec<NodeRecord>) {
104        self.pending_neighbours.insert(target, nodes);
105    }
106
107    /// Returns the local socket address associated with the service.
108    pub const fn local_addr(&self) -> SocketAddr {
109        self.local_addr
110    }
111
112    /// Returns the local [`NodeRecord`] associated with the service.
113    pub const fn local_enr(&self) -> NodeRecord {
114        self.local_enr
115    }
116
117    /// Encodes the packet, sends it and returns the hash.
118    fn send_packet(&self, msg: Message, to: SocketAddr) -> B256 {
119        let (payload, hash) = msg.encode(&self.secret_key);
120        let _ = self.egress.try_send((payload, to));
121        hash
122    }
123
124    fn send_neighbours_timeout(&self) -> u64 {
125        (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(30)).as_secs()
126    }
127}
128
129impl Stream for MockDiscovery {
130    type Item = MockEvent;
131
132    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133        let this = self.get_mut();
134        // process all incoming commands
135        while let Poll::Ready(maybe_cmd) = this.command_rx.poll_recv(cx) {
136            let Some(cmd) = maybe_cmd else { return Poll::Ready(None) };
137            match cmd {
138                MockCommand::MockPong { node_id } => {
139                    this.queue_pong(node_id);
140                }
141                MockCommand::MockNeighbours { target, nodes } => {
142                    this.queue_neighbours(target, nodes);
143                }
144            }
145        }
146
147        while let Poll::Ready(Some(event)) = this.ingress.poll_recv(cx) {
148            match event {
149                IngressEvent::RecvError(_) => {}
150                IngressEvent::BadPacket(from, err, data) => {
151                    debug!(target: "discv4", ?from, %err, packet=?hex::encode(&data), "bad packet");
152                }
153                IngressEvent::Packet(remote_addr, Packet { msg, node_id, hash }) => match msg {
154                    Message::Ping(ping) => {
155                        if this.pending_pongs.remove(&node_id) {
156                            let pong = Pong {
157                                to: ping.from,
158                                echo: hash,
159                                expire: ping.expire,
160                                enr_sq: None,
161                            };
162                            let msg = Message::Pong(pong.clone());
163                            this.send_packet(msg, remote_addr);
164                            return Poll::Ready(Some(MockEvent::Pong {
165                                ping,
166                                pong,
167                                to: remote_addr,
168                            }))
169                        }
170                    }
171                    Message::Pong(_) | Message::Neighbours(_) => {}
172                    Message::FindNode(msg) => {
173                        if let Some(nodes) = this.pending_neighbours.remove(&msg.id) {
174                            let msg = Message::Neighbours(Neighbours {
175                                nodes: nodes.clone(),
176                                expire: this.send_neighbours_timeout(),
177                            });
178                            this.send_packet(msg, remote_addr);
179                            return Poll::Ready(Some(MockEvent::Neighbours {
180                                nodes,
181                                to: remote_addr,
182                            }))
183                        }
184                    }
185                    Message::EnrRequest(_) | Message::EnrResponse(_) => todo!(),
186                },
187            }
188        }
189
190        Poll::Pending
191    }
192}
193
194/// Represents the event types produced by the mock service.
195#[derive(Debug)]
196pub enum MockEvent {
197    /// A Pong event, consisting of the original Ping packet, the corresponding Pong packet,
198    /// and the recipient's socket address.
199    Pong {
200        /// The original Ping packet.
201        ping: Ping,
202        /// The corresponding Pong packet.
203        pong: Pong,
204        /// The recipient's socket address.
205        to: SocketAddr,
206    },
207    /// A Neighbours event, containing a list of node records and the recipient's socket address.
208    Neighbours {
209        /// The list of node records.
210        nodes: Vec<NodeRecord>,
211        /// The recipient's socket address.
212        to: SocketAddr,
213    },
214}
215
216/// Represents commands for interacting with the `MockDiscovery` service.
217#[derive(Debug)]
218pub enum MockCommand {
219    /// A command to simulate a Pong event, including the node ID of the recipient.
220    MockPong {
221        /// The node ID of the recipient.
222        node_id: PeerId,
223    },
224    /// A command to simulate a Neighbours event, including the target node ID and a list of node
225    /// records.
226    MockNeighbours {
227        /// The target node ID.
228        target: PeerId,
229        /// The list of node records.
230        nodes: Vec<NodeRecord>,
231    },
232}
233
234/// Creates a new testing instance for [`Discv4`] and its service
235pub async fn create_discv4() -> (Discv4, Discv4Service) {
236    let fork_id = ForkId { hash: ForkHash(hex!("743f3d89")), next: 16191202 };
237    create_discv4_with_config(Discv4Config::builder().add_eip868_pair("eth", fork_id).build()).await
238}
239
240/// Creates a new testing instance for [`Discv4`] and its service with the given config.
241pub async fn create_discv4_with_config(config: Discv4Config) -> (Discv4, Discv4Service) {
242    let mut rng = thread_rng();
243    let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
244    let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
245    let id = pk2id(&pk);
246    let local_enr =
247        NodeRecord { address: socket.ip(), tcp_port: socket.port(), udp_port: socket.port(), id };
248    Discv4::bind(socket, local_enr, secret_key, config).await.unwrap()
249}
250
251/// Generates a random [`NodeEndpoint`] using the provided random number generator.
252pub fn rng_endpoint(rng: &mut impl Rng) -> NodeEndpoint {
253    let address = if rng.gen() {
254        let mut ip = [0u8; 4];
255        rng.fill_bytes(&mut ip);
256        IpAddr::V4(ip.into())
257    } else {
258        let mut ip = [0u8; 16];
259        rng.fill_bytes(&mut ip);
260        IpAddr::V6(ip.into())
261    };
262    NodeEndpoint { address, tcp_port: rng.gen(), udp_port: rng.gen() }
263}
264
265/// Generates a random [`NodeRecord`] using the provided random number generator.
266pub fn rng_record(rng: &mut impl RngCore) -> NodeRecord {
267    let NodeEndpoint { address, udp_port, tcp_port } = rng_endpoint(rng);
268    NodeRecord { address, tcp_port, udp_port, id: rng.gen() }
269}
270
271/// Generates a random IPv6 [`NodeRecord`] using the provided random number generator.
272pub fn rng_ipv6_record(rng: &mut impl RngCore) -> NodeRecord {
273    let mut ip = [0u8; 16];
274    rng.fill_bytes(&mut ip);
275    let address = IpAddr::V6(ip.into());
276    NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: rng.gen() }
277}
278
279/// Generates a random IPv4 [`NodeRecord`] using the provided random number generator.
280pub fn rng_ipv4_record(rng: &mut impl RngCore) -> NodeRecord {
281    let mut ip = [0u8; 4];
282    rng.fill_bytes(&mut ip);
283    let address = IpAddr::V4(ip.into());
284    NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: rng.gen() }
285}
286
287/// Generates a random [`Message`] using the provided random number generator.
288pub fn rng_message(rng: &mut impl RngCore) -> Message {
289    match rng.gen_range(1..=4) {
290        1 => Message::Ping(Ping {
291            from: rng_endpoint(rng),
292            to: rng_endpoint(rng),
293            expire: rng.gen(),
294            enr_sq: None,
295        }),
296        2 => Message::Pong(Pong {
297            to: rng_endpoint(rng),
298            echo: rng.gen(),
299            expire: rng.gen(),
300            enr_sq: None,
301        }),
302        3 => Message::FindNode(FindNode { id: rng.gen(), expire: rng.gen() }),
303        4 => {
304            let num: usize = rng.gen_range(1..=SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS);
305            Message::Neighbours(Neighbours {
306                nodes: std::iter::repeat_with(|| rng_record(rng)).take(num).collect(),
307                expire: rng.gen(),
308            })
309        }
310        _ => unreachable!(),
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Discv4Event;
    use std::net::Ipv4Addr;

    /// Runs a real discv4 service and a [`MockDiscovery`] on two local UDP sockets. The mock is
    /// primed to answer a `Ping` and a `FindNode`, and the test checks that the real service
    /// receives both answers.
    #[tokio::test]
    async fn can_mock_discovery() {
        reth_tracing::init_test_tracing();

        let mut rng = thread_rng();
        let (_, mut service) = create_discv4().await;
        // keep the command sender alive: the mock's stream ends once it is dropped
        let (mut mockv4, _cmd) = MockDiscovery::new().await.unwrap();

        let mock_enr = mockv4.local_enr();

        // we only want to test internally
        service.local_enr_mut().address = IpAddr::V4(Ipv4Addr::UNSPECIFIED);

        let discv_addr = service.local_addr();
        let discv_enr = service.local_enr();
        // the mock sees the service's packets arriving from its port on loopback
        let expected_to = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), discv_addr.port());

        // make sure it responds with a Pong
        mockv4.queue_pong(discv_enr.id);

        // This sends a ping to the mock service
        service.add_node(mock_enr);

        // process the mock pong
        match mockv4.next().await.unwrap() {
            MockEvent::Pong { to, .. } => assert_eq!(to, expected_to),
            other => panic!("expected a Pong event, got {other:?}"),
        }

        // discovery service received mocked pong
        let event = service.next().await.unwrap();
        assert_eq!(event, Discv4Event::Pong);

        assert!(service.contains_node(mock_enr.id));

        let mock_nodes =
            std::iter::repeat_with(|| rng_record(&mut rng)).take(5).collect::<Vec<_>>();

        mockv4.queue_neighbours(discv_enr.id, mock_nodes.clone());

        // start lookup
        service.lookup_self();

        match mockv4.next().await.unwrap() {
            MockEvent::Neighbours { nodes, to } => {
                assert_eq!(to, expected_to);
                assert_eq!(nodes, mock_nodes);
            }
            other => panic!("expected a Neighbours event, got {other:?}"),
        }

        // discovery service received mocked neighbours
        let event = service.next().await.unwrap();
        assert_eq!(event, Discv4Event::Neighbours);
    }
}