1use crate::{
6 proto::{FindNode, Message, Neighbours, NodeEndpoint, Packet, Ping, Pong},
7 receive_loop, send_loop, Discv4, Discv4Config, Discv4Service, EgressSender, IngressEvent,
8 IngressReceiver, PeerId, SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS,
9};
10use alloy_primitives::{hex, B256, B512};
11use rand_08::{thread_rng, Rng, RngCore};
12use reth_ethereum_forks::{ForkHash, ForkId};
13use reth_network_peers::{pk2id, NodeRecord};
14use secp256k1::{SecretKey, SECP256K1};
15use std::{
16 collections::{HashMap, HashSet},
17 io,
18 net::{IpAddr, SocketAddr},
19 pin::Pin,
20 str::FromStr,
21 sync::Arc,
22 task::{Context, Poll},
23 time::{Duration, SystemTime, UNIX_EPOCH},
24};
25use tokio::{
26 net::UdpSocket,
27 sync::mpsc,
28 task::{JoinHandle, JoinSet},
29};
30use tokio_stream::{Stream, StreamExt};
31use tracing::debug;
32
/// A mock discv4 node that answers pings and `FindNode` requests from preconfigured queues.
///
/// Replies are only sent for peers/targets registered via [`MockDiscovery::queue_pong`] and
/// [`MockDiscovery::queue_neighbours`] (or the equivalent [`MockCommand`]s).
#[derive(Debug)]
pub struct MockDiscovery {
    /// Address the UDP socket is bound to.
    local_addr: SocketAddr,
    /// Node record of the mock, built from `local_addr` and the generated public key.
    local_enr: NodeRecord,
    /// Key passed to `Message::encode` for every outgoing packet.
    secret_key: SecretKey,
    /// Shared UDP socket; this handle keeps it alive for as long as the mock exists.
    _udp: Arc<UdpSocket>,
    /// UDP receive/send loop tasks; tokio aborts all tasks of a `JoinSet` when it is dropped.
    _tasks: JoinSet<()>,
    /// Events produced by the UDP receive loop.
    ingress: IngressReceiver,
    /// Channel of `(payload, destination)` pairs consumed by the UDP send loop.
    egress: EgressSender,
    /// Peers whose next ping is answered with a pong.
    pending_pongs: HashSet<PeerId>,
    /// Neighbour records to send back, keyed by the `FindNode` target id.
    pending_neighbours: HashMap<PeerId, Vec<NodeRecord>>,
    /// Receiving half of the command channel whose sender is returned by `MockDiscovery::new`.
    command_rx: mpsc::Receiver<MockCommand>,
}
49
50impl MockDiscovery {
51 pub async fn new() -> io::Result<(Self, mpsc::Sender<MockCommand>)> {
53 let mut rng = thread_rng();
54 let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
55 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
56 let id = pk2id(&pk);
57 let socket = Arc::new(UdpSocket::bind(socket).await?);
58 let local_addr = socket.local_addr()?;
59 let local_enr = NodeRecord {
60 address: local_addr.ip(),
61 tcp_port: local_addr.port(),
62 udp_port: local_addr.port(),
63 id,
64 };
65
66 let (ingress_tx, ingress_rx) = mpsc::channel(128);
67 let (egress_tx, egress_rx) = mpsc::channel(128);
68 let mut tasks = JoinSet::<()>::new();
69
70 let udp = Arc::clone(&socket);
71 tasks.spawn(receive_loop(udp, ingress_tx, local_enr.id));
72
73 let udp = Arc::clone(&socket);
74 tasks.spawn(send_loop(udp, egress_rx));
75
76 let (tx, command_rx) = mpsc::channel(128);
77 let this = Self {
78 _tasks: tasks,
79 ingress: ingress_rx,
80 egress: egress_tx,
81 local_addr,
82 local_enr,
83 secret_key,
84 _udp: socket,
85 pending_pongs: Default::default(),
86 pending_neighbours: Default::default(),
87 command_rx,
88 };
89 Ok((this, tx))
90 }
91
92 pub fn spawn(self) -> JoinHandle<()> {
94 tokio::task::spawn(async move {
95 let _: Vec<_> = self.collect().await;
96 })
97 }
98
99 pub fn queue_pong(&mut self, from: PeerId) {
101 self.pending_pongs.insert(from);
102 }
103
104 pub fn queue_neighbours(&mut self, target: PeerId, nodes: Vec<NodeRecord>) {
106 self.pending_neighbours.insert(target, nodes);
107 }
108
109 pub const fn local_addr(&self) -> SocketAddr {
111 self.local_addr
112 }
113
114 pub const fn local_enr(&self) -> NodeRecord {
116 self.local_enr
117 }
118
119 fn send_packet(&self, msg: Message, to: SocketAddr) -> B256 {
121 let (payload, hash) = msg.encode(&self.secret_key);
122 let _ = self.egress.try_send((payload, to));
123 hash
124 }
125
126 fn send_neighbours_timeout(&self) -> u64 {
127 (SystemTime::now().duration_since(UNIX_EPOCH).unwrap() + Duration::from_secs(30)).as_secs()
128 }
129}
130
131impl Stream for MockDiscovery {
132 type Item = MockEvent;
133
134 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135 let this = self.get_mut();
136 while let Poll::Ready(maybe_cmd) = this.command_rx.poll_recv(cx) {
138 let Some(cmd) = maybe_cmd else { return Poll::Ready(None) };
139 match cmd {
140 MockCommand::MockPong { node_id } => {
141 this.queue_pong(node_id);
142 }
143 MockCommand::MockNeighbours { target, nodes } => {
144 this.queue_neighbours(target, nodes);
145 }
146 }
147 }
148
149 while let Poll::Ready(Some(event)) = this.ingress.poll_recv(cx) {
150 match event {
151 IngressEvent::RecvError(_) => {}
152 IngressEvent::BadPacket(from, err, data) => {
153 debug!(target: "discv4", ?from, %err, packet=?hex::encode(&data), "bad packet");
154 }
155 IngressEvent::Packet(remote_addr, Packet { msg, node_id, hash }) => match msg {
156 Message::Ping(ping) => {
157 if this.pending_pongs.remove(&node_id) {
158 let pong = Pong {
159 to: ping.from,
160 echo: hash,
161 expire: ping.expire,
162 enr_sq: None,
163 };
164 let msg = Message::Pong(pong.clone());
165 this.send_packet(msg, remote_addr);
166 return Poll::Ready(Some(MockEvent::Pong {
167 ping,
168 pong,
169 to: remote_addr,
170 }))
171 }
172 }
173 Message::Pong(_) | Message::Neighbours(_) => {}
174 Message::FindNode(msg) => {
175 if let Some(nodes) = this.pending_neighbours.remove(&msg.id) {
176 let msg = Message::Neighbours(Neighbours {
177 nodes: nodes.clone(),
178 expire: this.send_neighbours_timeout(),
179 });
180 this.send_packet(msg, remote_addr);
181 return Poll::Ready(Some(MockEvent::Neighbours {
182 nodes,
183 to: remote_addr,
184 }))
185 }
186 }
187 Message::EnrRequest(_) | Message::EnrResponse(_) => todo!(),
188 },
189 }
190 }
191
192 Poll::Pending
193 }
194}
195
/// Events yielded by the [`MockDiscovery`] stream after it has answered a request.
#[derive(Debug)]
pub enum MockEvent {
    /// A pong was sent in reply to a ping from a queued peer.
    Pong {
        /// The ping that was received.
        ping: Ping,
        /// The pong that was sent back.
        pong: Pong,
        /// Address the pong was sent to (the ping's source address).
        to: SocketAddr,
    },
    /// A neighbours message was sent in reply to a `FindNode` for a queued target.
    Neighbours {
        /// The records that were sent.
        nodes: Vec<NodeRecord>,
        /// Address the neighbours message was sent to (the request's source address).
        to: SocketAddr,
    },
}
217
/// Commands that configure a [`MockDiscovery`] through its command channel.
#[derive(Debug)]
pub enum MockCommand {
    /// Answer the next ping from `node_id` with a pong.
    MockPong {
        /// Peer whose next ping should be answered.
        node_id: PeerId,
    },
    /// Answer the next `FindNode` request for `target` with `nodes`.
    MockNeighbours {
        /// Target id of the `FindNode` request to answer.
        target: PeerId,
        /// Records to send back in the neighbours reply.
        nodes: Vec<NodeRecord>,
    },
}
235
236pub async fn create_discv4() -> (Discv4, Discv4Service) {
238 let fork_id = ForkId { hash: ForkHash(hex!("743f3d89")), next: 16191202 };
239 create_discv4_with_config(Discv4Config::builder().add_eip868_pair("eth", fork_id).build()).await
240}
241
242pub async fn create_discv4_with_config(config: Discv4Config) -> (Discv4, Discv4Service) {
244 let mut rng = thread_rng();
245 let socket = SocketAddr::from_str("0.0.0.0:0").unwrap();
246 let (secret_key, pk) = SECP256K1.generate_keypair(&mut rng);
247 let id = pk2id(&pk);
248 let local_enr =
249 NodeRecord { address: socket.ip(), tcp_port: socket.port(), udp_port: socket.port(), id };
250 Discv4::bind(socket, local_enr, secret_key, config).await.unwrap()
251}
252
253pub fn rng_endpoint(rng: &mut impl Rng) -> NodeEndpoint {
255 let address = if rng.gen() {
256 let mut ip = [0u8; 4];
257 rng.fill_bytes(&mut ip);
258 IpAddr::V4(ip.into())
259 } else {
260 let mut ip = [0u8; 16];
261 rng.fill_bytes(&mut ip);
262 IpAddr::V6(ip.into())
263 };
264 NodeEndpoint { address, tcp_port: rng.gen(), udp_port: rng.gen() }
265}
266
267pub fn rng_record(rng: &mut impl RngCore) -> NodeRecord {
269 let NodeEndpoint { address, udp_port, tcp_port } = rng_endpoint(rng);
270 NodeRecord { address, tcp_port, udp_port, id: B512::random() }
272}
273
274pub fn rng_ipv6_record(rng: &mut impl RngCore) -> NodeRecord {
276 let mut ip = [0u8; 16];
277 rng.fill_bytes(&mut ip);
278 let address = IpAddr::V6(ip.into());
279 NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: B512::random() }
281}
282
283pub fn rng_ipv4_record(rng: &mut impl RngCore) -> NodeRecord {
285 let mut ip = [0u8; 4];
286 rng.fill_bytes(&mut ip);
287 let address = IpAddr::V4(ip.into());
288 NodeRecord { address, tcp_port: rng.gen(), udp_port: rng.gen(), id: B512::random() }
290}
291
292pub fn rng_message(rng: &mut impl RngCore) -> Message {
294 match rng.gen_range(1..=4) {
295 1 => Message::Ping(Ping {
296 from: rng_endpoint(rng),
297 to: rng_endpoint(rng),
298 expire: rng.gen(),
299 enr_sq: None,
300 }),
301 2 => Message::Pong(Pong {
302 to: rng_endpoint(rng),
303 echo: B256::random(),
304 expire: rng.gen(),
305 enr_sq: None,
306 }),
307 3 => Message::FindNode(FindNode { id: B512::random(), expire: rng.gen() }),
308 4 => {
309 let num: usize = rng.gen_range(1..=SAFE_MAX_DATAGRAM_NEIGHBOUR_RECORDS);
310 Message::Neighbours(Neighbours {
311 nodes: std::iter::repeat_with(|| rng_record(rng)).take(num).collect(),
312 expire: rng.gen(),
313 })
314 }
315 _ => unreachable!(),
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Discv4Event;
    use std::net::Ipv4Addr;

    /// Runs a ping/pong exchange followed by a `FindNode`/neighbours exchange between a real
    /// [`Discv4Service`] and a [`MockDiscovery`].
    #[tokio::test]
    async fn can_mock_discovery() {
        reth_tracing::init_test_tracing();

        let mut rng = thread_rng();
        let (_, mut service) = create_discv4().await;
        // Keep the command sender alive: once all senders are dropped the mock's stream ends.
        let (mut mockv4, _cmd) = MockDiscovery::new().await.unwrap();

        let mock_enr = mockv4.local_enr();

        // Advertise the unspecified IPv4 address in the service's own record.
        service.local_enr_mut().address = IpAddr::V4(Ipv4Addr::UNSPECIFIED);

        let discv_addr = service.local_addr();
        let discv_enr = service.local_enr();
        // The mock is expected to see the service's packets coming from loopback.
        let expected_to = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), discv_addr.port());

        // Answer the service's next ping with a pong.
        mockv4.queue_pong(discv_enr.id);

        // Adding the mock as a node is expected to make the service ping it.
        service.add_node(mock_enr);

        let MockEvent::Pong { to, .. } = mockv4.next().await.unwrap() else {
            unreachable!("invalid response")
        };
        assert_eq!(to, expected_to);

        assert_eq!(service.next().await.unwrap(), Discv4Event::Pong);
        assert!(service.contains_node(mock_enr.id));

        let mock_nodes: Vec<_> =
            std::iter::repeat_with(|| rng_record(&mut rng)).take(5).collect();
        // Answer the service's self-lookup with the generated records.
        mockv4.queue_neighbours(discv_enr.id, mock_nodes.clone());

        service.lookup_self();

        let MockEvent::Neighbours { nodes, to } = mockv4.next().await.unwrap() else {
            unreachable!("invalid response")
        };
        assert_eq!(to, expected_to);
        assert_eq!(nodes, mock_nodes);

        assert_eq!(service.next().await.unwrap(), Discv4Event::Neighbours);
    }
}
389}