1use crate::{
4 proto::{FindNode, Message, Neighbours, NodeEndpoint, Packet, Ping, Pong},
5 receive_loop, send_loop, Discv4, Discv4Config, Discv4Service, EgressSender, IngressEvent,
6 IngressReceiver, PeerId, SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS,
7};
8use alloy_primitives::{hex, B256};
9use rand::{thread_rng, Rng, RngCore};
10use reth_ethereum_forks::{ForkHash, ForkId};
11use reth_network_peers::{pk2id, NodeRecord};
12use secp256k1::{SecretKey, SECP256K1};
13use std::{
14 collections::{HashMap, HashSet},
15 io,
16 net::{IpAddr, SocketAddr},
17 pin::Pin,
18 str::FromStr,
19 sync::Arc,
20 task::{Context, Poll},
21 time::{Duration, SystemTime, UNIX_EPOCH},
22};
23use tokio::{
24 net::UdpSocket,
25 sync::mpsc,
26 task::{JoinHandle, JoinSet},
27};
28use tokio_stream::{Stream, StreamExt};
29use tracing::debug;
30
/// Mock discovery node that answers incoming `Ping` and `FindNode` packets with responses
/// queued in advance, either directly via [`MockDiscovery::queue_pong`] /
/// [`MockDiscovery::queue_neighbours`] or through [`MockCommand`]s.
#[derive(Debug)]
pub struct MockDiscovery {
    /// Address the UDP socket is bound to.
    local_addr: SocketAddr,
    /// Node record of this mock (built from `local_addr` and the generated key).
    local_enr: NodeRecord,
    /// Key used to encode (sign) outgoing packets.
    secret_key: SecretKey,
    /// Shared handle to the bound UDP socket, also held by the receive and send loops.
    _udp: Arc<UdpSocket>,
    /// The spawned receive and send loops; held so they live as long as the mock.
    _tasks: JoinSet<()>,
    /// Events produced by the receive loop.
    ingress: IngressReceiver,
    /// Channel to the send loop for outgoing encoded packets.
    egress: EgressSender,
    /// Peers whose next `Ping` will be answered with a `Pong`.
    pending_pongs: HashSet<PeerId>,
    /// Neighbours to return for a `FindNode` request, keyed by the request's target id.
    pending_neighbours: HashMap<PeerId, Vec<NodeRecord>>,
    /// Receiver for commands that queue responses from outside the mock.
    command_rx: mpsc::Receiver<MockCommand>,
}
47
48impl MockDiscovery {
49 pub async fn new() -> io::Result<(Self, mpsc::Sender<MockCommand>)> {
51 let mut rng = thread_rng();
52 let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
53 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
54 let id = pk2id(&pk);
55 let socket = Arc::new(UdpSocket::bind(socket).await?);
56 let local_addr = socket.local_addr()?;
57 let local_enr = NodeRecord {
58 address: local_addr.ip(),
59 tcp_port: local_addr.port(),
60 udp_port: local_addr.port(),
61 id,
62 };
63
64 let (ingress_tx, ingress_rx) = mpsc::channel(128);
65 let (egress_tx, egress_rx) = mpsc::channel(128);
66 let mut tasks = JoinSet::<()>::new();
67
68 let udp = Arc::clone(&socket);
69 tasks.spawn(receive_loop(udp, ingress_tx, local_enr.id));
70
71 let udp = Arc::clone(&socket);
72 tasks.spawn(send_loop(udp, egress_rx));
73
74 let (tx, command_rx) = mpsc::channel(128);
75 let this = Self {
76 _tasks: tasks,
77 ingress: ingress_rx,
78 egress: egress_tx,
79 local_addr,
80 local_enr,
81 secret_key,
82 _udp: socket,
83 pending_pongs: Default::default(),
84 pending_neighbours: Default::default(),
85 command_rx,
86 };
87 Ok((this, tx))
88 }
89
90 pub fn spawn(self) -> JoinHandle<()> {
92 tokio::task::spawn(async move {
93 let _: Vec<_> = self.collect().await;
94 })
95 }
96
97 pub fn queue_pong(&mut self, from: PeerId) {
99 self.pending_pongs.insert(from);
100 }
101
102 pub fn queue_neighbours(&mut self, target: PeerId, nodes: Vec<NodeRecord>) {
104 self.pending_neighbours.insert(target, nodes);
105 }
106
107 pub const fn local_addr(&self) -> SocketAddr {
109 self.local_addr
110 }
111
112 pub const fn local_enr(&self) -> NodeRecord {
114 self.local_enr
115 }
116
117 fn send_packet(&self, msg: Message, to: SocketAddr) -> B256 {
119 let (payload, hash) = msg.encode(&self.secret_key);
120 let _ = self.egress.try_send((payload, to));
121 hash
122 }
123
124 fn send_neighbours_timeout(&self) -> u64 {
125 (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(30)).as_secs()
126 }
127}
128
129impl Stream for MockDiscovery {
130 type Item = MockEvent;
131
132 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
133 let this = self.get_mut();
134 while let Poll::Ready(maybe_cmd) = this.command_rx.poll_recv(cx) {
136 let Some(cmd) = maybe_cmd else { return Poll::Ready(None) };
137 match cmd {
138 MockCommand::MockPong { node_id } => {
139 this.queue_pong(node_id);
140 }
141 MockCommand::MockNeighbours { target, nodes } => {
142 this.queue_neighbours(target, nodes);
143 }
144 }
145 }
146
147 while let Poll::Ready(Some(event)) = this.ingress.poll_recv(cx) {
148 match event {
149 IngressEvent::RecvError(_) => {}
150 IngressEvent::BadPacket(from, err, data) => {
151 debug!(target: "discv4", ?from, %err, packet=?hex::encode(&data), "bad packet");
152 }
153 IngressEvent::Packet(remote_addr, Packet { msg, node_id, hash }) => match msg {
154 Message::Ping(ping) => {
155 if this.pending_pongs.remove(&node_id) {
156 let pong = Pong {
157 to: ping.from,
158 echo: hash,
159 expire: ping.expire,
160 enr_sq: None,
161 };
162 let msg = Message::Pong(pong.clone());
163 this.send_packet(msg, remote_addr);
164 return Poll::Ready(Some(MockEvent::Pong {
165 ping,
166 pong,
167 to: remote_addr,
168 }))
169 }
170 }
171 Message::Pong(_) | Message::Neighbours(_) => {}
172 Message::FindNode(msg) => {
173 if let Some(nodes) = this.pending_neighbours.remove(&msg.id) {
174 let msg = Message::Neighbours(Neighbours {
175 nodes: nodes.clone(),
176 expire: this.send_neighbours_timeout(),
177 });
178 this.send_packet(msg, remote_addr);
179 return Poll::Ready(Some(MockEvent::Neighbours {
180 nodes,
181 to: remote_addr,
182 }))
183 }
184 }
185 Message::EnrRequest(_) | Message::EnrResponse(_) => todo!(),
186 },
187 }
188 }
189
190 Poll::Pending
191 }
192}
193
/// Event yielded by the [`MockDiscovery`] stream each time it answers a request with a
/// queued response.
#[derive(Debug)]
pub enum MockEvent {
    /// A `Pong` was handed to the send loop in reply to a `Ping`.
    Pong {
        /// The ping that was received.
        ping: Ping,
        /// The pong that was sent back.
        pong: Pong,
        /// Remote address the pong was addressed to.
        to: SocketAddr,
    },
    /// A `Neighbours` packet was handed to the send loop in reply to a `FindNode` request.
    Neighbours {
        /// The nodes included in the reply.
        nodes: Vec<NodeRecord>,
        /// Remote address the reply was addressed to.
        to: SocketAddr,
    },
}
215
/// Commands sent to a [`MockDiscovery`] to queue responses while it is running.
#[derive(Debug)]
pub enum MockCommand {
    /// Answer the next `Ping` from `node_id` with a `Pong`.
    MockPong {
        /// Peer whose next ping should be answered.
        node_id: PeerId,
    },
    /// Answer the next `FindNode` request targeting `target` with `nodes`.
    MockNeighbours {
        /// Target id of the `FindNode` request to answer (the id being looked up, not the
        /// requester).
        target: PeerId,
        /// Nodes to return in the `Neighbours` reply.
        nodes: Vec<NodeRecord>,
    },
}
233
234pub async fn create_discv4() -> (Discv4, Discv4Service) {
236 let fork_id = ForkId { hash: ForkHash(hex!("743f3d89")), next: 16191202 };
237 create_discv4_with_config(Discv4Config::builder().add_eip868_pair("eth", fork_id).build()).await
238}
239
240pub async fn create_discv4_with_config(config: Discv4Config) -> (Discv4, Discv4Service) {
242 let mut rng = thread_rng();
243 let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
244 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
245 let id = pk2id(&pk);
246 let local_enr =
247 NodeRecord { address: socket.ip(), tcp_port: socket.port(), udp_port: socket.port(), id };
248 Discv4::bind(socket, local_enr, secret_key, config).await.unwrap()
249}
250
251pub fn rng_endpoint(rng: &mut impl Rng) -> NodeEndpoint {
253 let address = if rng.gen() {
254 let mut ip = [0u8; 4];
255 rng.fill_bytes(&mut ip);
256 IpAddr::V4(ip.into())
257 } else {
258 let mut ip = [0u8; 16];
259 rng.fill_bytes(&mut ip);
260 IpAddr::V6(ip.into())
261 };
262 NodeEndpoint { address, tcp_port: rng.gen(), udp_port: rng.gen() }
263}
264
265pub fn rng_record(rng: &mut impl RngCore) -> NodeRecord {
267 let NodeEndpoint { address, udp_port, tcp_port } = rng_endpoint(rng);
268 NodeRecord { address, tcp_port, udp_port, id: rng.gen() }
269}
270
271pub fn rng_ipv6_record(rng: &mut impl RngCore) -> NodeRecord {
273 let mut ip = [0u8; 16];
274 rng.fill_bytes(&mut ip);
275 let address = IpAddr::V6(ip.into());
276 NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: rng.gen() }
277}
278
279pub fn rng_ipv4_record(rng: &mut impl RngCore) -> NodeRecord {
281 let mut ip = [0u8; 4];
282 rng.fill_bytes(&mut ip);
283 let address = IpAddr::V4(ip.into());
284 NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: rng.gen() }
285}
286
287pub fn rng_message(rng: &mut impl RngCore) -> Message {
289 match rng.gen_range(1..=4) {
290 1 => Message::Ping(Ping {
291 from: rng_endpoint(rng),
292 to: rng_endpoint(rng),
293 expire: rng.gen(),
294 enr_sq: None,
295 }),
296 2 => Message::Pong(Pong {
297 to: rng_endpoint(rng),
298 echo: rng.gen(),
299 expire: rng.gen(),
300 enr_sq: None,
301 }),
302 3 => Message::FindNode(FindNode { id: rng.gen(), expire: rng.gen() }),
303 4 => {
304 let num: usize = rng.gen_range(1..=SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS);
305 Message::Neighbours(Neighbours {
306 nodes: std::iter::repeat_with(|| rng_record(rng)).take(num).collect(),
307 expire: rng.gen(),
308 })
309 }
310 _ => unreachable!(),
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Discv4Event;
    use std::net::Ipv4Addr;

    /// Drives a real `Discv4Service` against a `MockDiscovery` with pre-queued replies:
    /// first a ping/pong exchange, then a self-lookup answered with mock neighbours.
    #[tokio::test]
    async fn can_mock_discovery() {
        reth_tracing::init_test_tracing();

        let mut rng = thread_rng();
        let (_, mut service) = create_discv4().await;
        let (mut mockv4, _cmd) = MockDiscovery::new().await.unwrap();

        let mock_enr = mockv4.local_enr();

        service.local_enr_mut().address = IpAddr::V4(Ipv4Addr::UNSPECIFIED);

        let discv_addr = service.local_addr();
        let discv_enr = service.local_enr();
        let expected_to = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), discv_addr.port());

        // Ping/pong: adding the mock as a node makes the service ping it.
        mockv4.queue_pong(discv_enr.id);
        service.add_node(mock_enr);

        let MockEvent::Pong { to, .. } = mockv4.next().await.unwrap() else {
            unreachable!("invalid response")
        };
        assert_eq!(to, expected_to);
        assert_eq!(service.next().await.unwrap(), Discv4Event::Pong);
        assert!(service.contains_node(mock_enr.id));

        // Lookup: the service's self-lookup is answered with the queued records.
        let mock_nodes: Vec<_> = (0..5).map(|_| rng_record(&mut rng)).collect();
        mockv4.queue_neighbours(discv_enr.id, mock_nodes.clone());
        service.lookup_self();

        let MockEvent::Neighbours { nodes, to } = mockv4.next().await.unwrap() else {
            unreachable!("invalid response")
        };
        assert_eq!(to, expected_to);
        assert_eq!(nodes, mock_nodes);
        assert_eq!(service.next().await.unwrap(), Discv4Event::Neighbours);
    }
}