reth_discv4/
test_utils.rs

1//! Mock discovery support
2
3// TODO(rand): update ::random calls after rand_09 migration
4
5use crate::{
6    proto::{FindNode, Message, Neighbours, NodeEndpoint, Packet, Ping, Pong},
7    receive_loop, send_loop, Discv4, Discv4Config, Discv4Service, EgressSender, IngressEvent,
8    IngressReceiver, PeerId, SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS,
9};
10use alloy_primitives::{hex, B256, B512};
11use rand_08::{thread_rng, Rng, RngCore};
12use reth_ethereum_forks::{ForkHash, ForkId};
13use reth_network_peers::{pk2id, NodeRecord};
14use secp256k1::{SecretKey, SECP256K1};
15use std::{
16    collections::{HashMap, HashSet},
17    io,
18    net::{IpAddr, SocketAddr},
19    pin::Pin,
20    str::FromStr,
21    sync::Arc,
22    task::{Context, Poll},
23    time::{Duration, SystemTime, UNIX_EPOCH},
24};
25use tokio::{
26    net::UdpSocket,
27    sync::mpsc,
28    task::{JoinHandle, JoinSet},
29};
30use tokio_stream::{Stream, StreamExt};
31use tracing::debug;
32
33/// Mock discovery node
34#[derive(Debug)]
35pub struct MockDiscovery {
36    local_addr: SocketAddr,
37    local_enr: NodeRecord,
38    secret_key: SecretKey,
39    _udp: Arc<UdpSocket>,
40    _tasks: JoinSet<()>,
41    /// Receiver for incoming messages
42    ingress: IngressReceiver,
43    /// Sender for sending outgoing messages
44    egress: EgressSender,
45    pending_pongs: HashSet<PeerId>,
46    pending_neighbours: HashMap<PeerId, Vec<NodeRecord>>,
47    command_rx: mpsc::Receiver<MockCommand>,
48}
49
50impl MockDiscovery {
51    /// Creates a new instance and opens a socket
52    pub async fn new() -> io::Result<(Self, mpsc::Sender<MockCommand>)> {
53        let mut rng = thread_rng();
54        let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
55        let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
56        let id = pk2id(&pk);
57        let socket = Arc::new(UdpSocket::bind(socket).await?);
58        let local_addr = socket.local_addr()?;
59        let local_enr = NodeRecord {
60            address: local_addr.ip(),
61            tcp_port: local_addr.port(),
62            udp_port: local_addr.port(),
63            id,
64        };
65
66        let (ingress_tx, ingress_rx) = mpsc::channel(128);
67        let (egress_tx, egress_rx) = mpsc::channel(128);
68        let mut tasks = JoinSet::<()>::new();
69
70        let udp = Arc::clone(&socket);
71        tasks.spawn(receive_loop(udp, ingress_tx, local_enr.id));
72
73        let udp = Arc::clone(&socket);
74        tasks.spawn(send_loop(udp, egress_rx));
75
76        let (tx, command_rx) = mpsc::channel(128);
77        let this = Self {
78            _tasks: tasks,
79            ingress: ingress_rx,
80            egress: egress_tx,
81            local_addr,
82            local_enr,
83            secret_key,
84            _udp: socket,
85            pending_pongs: Default::default(),
86            pending_neighbours: Default::default(),
87            command_rx,
88        };
89        Ok((this, tx))
90    }
91
92    /// Spawn and consume the stream.
93    pub fn spawn(self) -> JoinHandle<()> {
94        tokio::task::spawn(async move {
95            let _: Vec<_> = self.collect().await;
96        })
97    }
98
99    /// Queue a pending pong.
100    pub fn queue_pong(&mut self, from: PeerId) {
101        self.pending_pongs.insert(from);
102    }
103
104    /// Queue a pending Neighbours response.
105    pub fn queue_neighbours(&mut self, target: PeerId, nodes: Vec<NodeRecord>) {
106        self.pending_neighbours.insert(target, nodes);
107    }
108
109    /// Returns the local socket address associated with the service.
110    pub const fn local_addr(&self) -> SocketAddr {
111        self.local_addr
112    }
113
114    /// Returns the local [`NodeRecord`] associated with the service.
115    pub const fn local_enr(&self) -> NodeRecord {
116        self.local_enr
117    }
118
119    /// Encodes the packet, sends it and returns the hash.
120    fn send_packet(&self, msg: Message, to: SocketAddr) -> B256 {
121        let (payload, hash) = msg.encode(&self.secret_key);
122        let _ = self.egress.try_send((payload, to));
123        hash
124    }
125
126    fn send_neighbours_timeout(&self) -> u64 {
127        (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(30)).as_secs()
128    }
129}
130
131impl Stream for MockDiscovery {
132    type Item = MockEvent;
133
134    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135        let this = self.get_mut();
136        // process all incoming commands
137        while let Poll::Ready(maybe_cmd) = this.command_rx.poll_recv(cx) {
138            let Some(cmd) = maybe_cmd else { return Poll::Ready(None) };
139            match cmd {
140                MockCommand::MockPong { node_id } => {
141                    this.queue_pong(node_id);
142                }
143                MockCommand::MockNeighbours { target, nodes } => {
144                    this.queue_neighbours(target, nodes);
145                }
146            }
147        }
148
149        while let Poll::Ready(Some(event)) = this.ingress.poll_recv(cx) {
150            match event {
151                IngressEvent::RecvError(_) => {}
152                IngressEvent::BadPacket(from, err, data) => {
153                    debug!(target: "discv4", ?from, %err, packet=?hex::encode(&data), "bad packet");
154                }
155                IngressEvent::Packet(remote_addr, Packet { msg, node_id, hash }) => match msg {
156                    Message::Ping(ping) => {
157                        if this.pending_pongs.remove(&node_id) {
158                            let pong = Pong {
159                                to: ping.from,
160                                echo: hash,
161                                expire: ping.expire,
162                                enr_sq: None,
163                            };
164                            let msg = Message::Pong(pong.clone());
165                            this.send_packet(msg, remote_addr);
166                            return Poll::Ready(Some(MockEvent::Pong {
167                                ping,
168                                pong,
169                                to: remote_addr,
170                            }))
171                        }
172                    }
173                    Message::Pong(_) | Message::Neighbours(_) => {}
174                    Message::FindNode(msg) => {
175                        if let Some(nodes) = this.pending_neighbours.remove(&msg.id) {
176                            let msg = Message::Neighbours(Neighbours {
177                                nodes: nodes.clone(),
178                                expire: this.send_neighbours_timeout(),
179                            });
180                            this.send_packet(msg, remote_addr);
181                            return Poll::Ready(Some(MockEvent::Neighbours {
182                                nodes,
183                                to: remote_addr,
184                            }))
185                        }
186                    }
187                    Message::EnrRequest(_) | Message::EnrResponse(_) => todo!(),
188                },
189            }
190        }
191
192        Poll::Pending
193    }
194}
195
196/// Represents the event types produced by the mock service.
197#[derive(Debug)]
198pub enum MockEvent {
199    /// A Pong event, consisting of the original Ping packet, the corresponding Pong packet,
200    /// and the recipient's socket address.
201    Pong {
202        /// The original Ping packet.
203        ping: Ping,
204        /// The corresponding Pong packet.
205        pong: Pong,
206        /// The recipient's socket address.
207        to: SocketAddr,
208    },
209    /// A Neighbours event, containing a list of node records and the recipient's socket address.
210    Neighbours {
211        /// The list of node records.
212        nodes: Vec<NodeRecord>,
213        /// The recipient's socket address.
214        to: SocketAddr,
215    },
216}
217
218/// Represents commands for interacting with the `MockDiscovery` service.
219#[derive(Debug)]
220pub enum MockCommand {
221    /// A command to simulate a Pong event, including the node ID of the recipient.
222    MockPong {
223        /// The node ID of the recipient.
224        node_id: PeerId,
225    },
226    /// A command to simulate a Neighbours event, including the target node ID and a list of node
227    /// records.
228    MockNeighbours {
229        /// The target node ID.
230        target: PeerId,
231        /// The list of node records.
232        nodes: Vec<NodeRecord>,
233    },
234}
235
236/// Creates a new testing instance for [`Discv4`] and its service
237pub async fn create_discv4() -> (Discv4, Discv4Service) {
238    let fork_id = ForkId { hash: ForkHash(hex!("743f3d89")), next: 16191202 };
239    create_discv4_with_config(Discv4Config::builder().add_eip868_pair("eth", fork_id).build()).await
240}
241
242/// Creates a new testing instance for [`Discv4`] and its service with the given config.
243pub async fn create_discv4_with_config(config: Discv4Config) -> (Discv4, Discv4Service) {
244    let mut rng = thread_rng();
245    let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
246    let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
247    let id = pk2id(&pk);
248    let local_enr =
249        NodeRecord { address: socket.ip(), tcp_port: socket.port(), udp_port: socket.port(), id };
250    Discv4::bind(socket, local_enr, secret_key, config).await.unwrap()
251}
252
253/// Generates a random [`NodeEndpoint`] using the provided random number generator.
254pub fn rng_endpoint(rng: &mut impl Rng) -> NodeEndpoint {
255    let address = if rng.gen() {
256        let mut ip = [0u8; 4];
257        rng.fill_bytes(&mut ip);
258        IpAddr::V4(ip.into())
259    } else {
260        let mut ip = [0u8; 16];
261        rng.fill_bytes(&mut ip);
262        IpAddr::V6(ip.into())
263    };
264    NodeEndpoint { address, tcp_port: rng.gen(), udp_port: rng.gen() }
265}
266
267/// Generates a random [`NodeRecord`] using the provided random number generator.
268pub fn rng_record(rng: &mut impl RngCore) -> NodeRecord {
269    let NodeEndpoint { address, udp_port, tcp_port } = rng_endpoint(rng);
270    // TODO(rand)
271    NodeRecord { address, tcp_port, udp_port, id: B512::random() }
272}
273
274/// Generates a random IPv6 [`NodeRecord`] using the provided random number generator.
275pub fn rng_ipv6_record(rng: &mut impl RngCore) -> NodeRecord {
276    let mut ip = [0u8; 16];
277    rng.fill_bytes(&mut ip);
278    let address = IpAddr::V6(ip.into());
279    // TODO(rand)
280    NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: B512::random() }
281}
282
283/// Generates a random IPv4 [`NodeRecord`] using the provided random number generator.
284pub fn rng_ipv4_record(rng: &mut impl RngCore) -> NodeRecord {
285    let mut ip = [0u8; 4];
286    rng.fill_bytes(&mut ip);
287    let address = IpAddr::V4(ip.into());
288    // TODO(rand)
289    NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: B512::random() }
290}
291
292/// Generates a random [`Message`] using the provided random number generator.
293pub fn rng_message(rng: &mut impl RngCore) -> Message {
294    match rng.gen_range(1..=4) {
295        1 => Message::Ping(Ping {
296            from: rng_endpoint(rng),
297            to: rng_endpoint(rng),
298            expire: rng.gen(),
299            enr_sq: None,
300        }),
301        2 => Message::Pong(Pong {
302            to: rng_endpoint(rng),
303            echo: B256::random(),
304            expire: rng.gen(),
305            enr_sq: None,
306        }),
307        3 => Message::FindNode(FindNode { id: B512::random(), expire: rng.gen() }),
308        4 => {
309            let num: usize = rng.gen_range(1..=SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS);
310            Message::Neighbours(Neighbours {
311                nodes: std::iter::repeat_with(|| rng_record(rng)).take(num).collect(),
312                expire: rng.gen(),
313            })
314        }
315        _ => unreachable!(),
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Discv4Event;
    use std::net::Ipv4Addr;

    /// Runs a real discovery service and a [`MockDiscovery`] on two local UDP sockets, queues
    /// mocked responses on the mock, and checks that the real service receives them.
    #[tokio::test]
    async fn can_mock_discovery() {
        reth_tracing::init_test_tracing();

        let (_, mut service) = create_discv4().await;
        // Keep the command sender bound: dropping it closes the command channel, which ends the
        // mock's stream.
        let (mut mockv4, _cmd) = MockDiscovery::new().await.unwrap();

        let mock_enr = mockv4.local_enr();

        // we only want to test internally
        service.local_enr_mut().address = IpAddr::V4(Ipv4Addr::UNSPECIFIED);

        let discv_addr = service.local_addr();
        let discv_enr = service.local_enr();
        // The mock's replies are expected to reach the service's port on localhost.
        let expected_to = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), discv_addr.port());

        // make sure it responds with a Pong
        mockv4.queue_pong(discv_enr.id);

        // This sends a ping to the mock service
        service.add_node(mock_enr);

        // process the mock pong
        let MockEvent::Pong { to, .. } = mockv4.next().await.unwrap() else {
            unreachable!("invalid response")
        };
        assert_eq!(to, expected_to);

        // discovery service received mocked pong
        assert_eq!(service.next().await.unwrap(), Discv4Event::Pong);
        assert!(service.contains_node(mock_enr.id));

        let mock_nodes = {
            let mut rng = thread_rng();
            std::iter::repeat_with(|| rng_record(&mut rng)).take(5).collect::<Vec<_>>()
        };
        mockv4.queue_neighbours(discv_enr.id, mock_nodes.clone());

        // start lookup
        service.lookup_self();

        let MockEvent::Neighbours { nodes, to } = mockv4.next().await.unwrap() else {
            unreachable!("invalid response")
        };
        assert_eq!(to, expected_to);
        assert_eq!(nodes, mock_nodes);

        // discovery service received mocked neighbours
        assert_eq!(service.next().await.unwrap(), Discv4Event::Neighbours);
    }
}