1use crate::{
4 error::{LookupError, LookupResult},
5 resolver::Resolver,
6 sync::ResolveKind,
7 tree::{DnsEntry, LinkEntry, TreeRootEntry},
8};
9use alloy_primitives::keccak256;
10use data_encoding::BASE32_NOPAD;
11use enr::EnrKeyUnambiguous;
12use reth_tokio_util::ratelimit::{Rate, RateLimit};
13use std::{
14 collections::VecDeque,
15 future::Future,
16 num::NonZeroUsize,
17 pin::Pin,
18 sync::Arc,
19 task::{ready, Context, Poll},
20 time::Duration,
21};
22
/// Minimum number of bytes a base32-decoded subdomain hash must contain to be accepted.
const MIN_HASH_BYTES: usize = 12;
/// Maximum number of bytes a base32-decoded subdomain hash may contain: the length of a full
/// keccak256 digest (256 bits).
const MAX_HASH_BYTES: usize = 256 / 8;
27
28pub(crate) struct QueryPool<R: Resolver, K: EnrKeyUnambiguous> {
30 resolver: Arc<R>,
32 queued_queries: VecDeque<Query<K>>,
34 active_queries: Vec<Query<K>>,
36 queued_outcomes: VecDeque<QueryOutcome<K>>,
38 rate_limit: RateLimit,
40 lookup_timeout: Duration,
42}
43
44impl<R: Resolver, K: EnrKeyUnambiguous> QueryPool<R, K> {
47 pub(crate) fn new(
48 resolver: Arc<R>,
49 max_requests_per_sec: NonZeroUsize,
50 lookup_timeout: Duration,
51 ) -> Self {
52 Self {
53 resolver,
54 queued_queries: Default::default(),
55 active_queries: vec![],
56 queued_outcomes: Default::default(),
57 rate_limit: RateLimit::new(Rate::new(
58 max_requests_per_sec.get() as u64,
59 Duration::from_secs(1),
60 )),
61 lookup_timeout,
62 }
63 }
64
65 pub(crate) fn resolve_root(&mut self, link: LinkEntry<K>) {
67 let resolver = Arc::clone(&self.resolver);
68 let timeout = self.lookup_timeout;
69 self.queued_queries.push_back(Query::Root(Box::pin(resolve_root(resolver, link, timeout))))
70 }
71
72 pub(crate) fn resolve_entry(&mut self, link: LinkEntry<K>, hash: String, kind: ResolveKind) {
74 let resolver = Arc::clone(&self.resolver);
75 let timeout = self.lookup_timeout;
76 self.queued_queries
77 .push_back(Query::Entry(Box::pin(resolve_entry(resolver, link, hash, kind, timeout))))
78 }
79
80 pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<QueryOutcome<K>> {
82 loop {
83 if let Some(event) = self.queued_outcomes.pop_front() {
85 return Poll::Ready(event)
86 }
87
88 'queries: while self.active_queries.len() < self.rate_limit.limit() as usize {
90 if self.rate_limit.poll_ready(cx).is_ready() &&
91 let Some(query) = self.queued_queries.pop_front()
92 {
93 self.rate_limit.tick();
94 self.active_queries.push(query);
95 continue 'queries
96 }
97 break
98 }
99
100 for idx in (0..self.active_queries.len()).rev() {
102 let mut query = self.active_queries.swap_remove(idx);
103 if let Poll::Ready(outcome) = query.poll(cx) {
104 self.queued_outcomes.push_back(outcome);
105 } else {
106 self.active_queries.push(query);
108 }
109 }
110
111 if self.queued_outcomes.is_empty() {
112 return Poll::Pending
113 }
114 }
115 }
116}
117
118pub(crate) struct ResolveEntryResult<K: EnrKeyUnambiguous> {
121 pub(crate) entry: Option<LookupResult<DnsEntry<K>>>,
122 pub(crate) link: LinkEntry<K>,
123 pub(crate) hash: String,
124 pub(crate) kind: ResolveKind,
125}
126
127pub(crate) type ResolveRootResult<K> =
128 Result<(TreeRootEntry, LinkEntry<K>), (LookupError, LinkEntry<K>)>;
129
130type ResolveRootFuture<K> = Pin<Box<dyn Future<Output = ResolveRootResult<K>> + Send>>;
131
132type ResolveEntryFuture<K> = Pin<Box<dyn Future<Output = ResolveEntryResult<K>> + Send>>;
133
134enum Query<K: EnrKeyUnambiguous> {
135 Root(ResolveRootFuture<K>),
136 Entry(ResolveEntryFuture<K>),
137}
138
139impl<K: EnrKeyUnambiguous> Query<K> {
142 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<QueryOutcome<K>> {
144 match self {
145 Self::Root(query) => {
146 let outcome = ready!(query.as_mut().poll(cx));
147 Poll::Ready(QueryOutcome::Root(outcome))
148 }
149 Self::Entry(query) => {
150 let outcome = ready!(query.as_mut().poll(cx));
151 Poll::Ready(QueryOutcome::Entry(outcome))
152 }
153 }
154 }
155}
156
157pub(crate) enum QueryOutcome<K: EnrKeyUnambiguous> {
159 Root(ResolveRootResult<K>),
160 Entry(ResolveEntryResult<K>),
161}
162
163async fn resolve_entry<K: EnrKeyUnambiguous, R: Resolver>(
165 resolver: Arc<R>,
166 link: LinkEntry<K>,
167 hash: String,
168 kind: ResolveKind,
169 timeout: Duration,
170) -> ResolveEntryResult<K> {
171 let fqn = format!("{hash}.{}", link.domain);
172 let mut resp = ResolveEntryResult { entry: None, link, hash, kind };
173 match lookup_with_timeout::<R>(&resolver, &fqn, timeout).await {
174 Ok(Some(entry)) => {
175 resp.entry = Some(match verify_entry_hash(&resp.hash, &entry) {
176 Ok(()) => entry.parse::<DnsEntry<K>>().map_err(Into::into),
177 Err(err) => Err(err),
178 })
179 }
180 Err(err) => resp.entry = Some(Err(err)),
181 Ok(None) => {}
182 }
183 resp
184}
185
186fn verify_entry_hash(hash: &str, entry_txt: &str) -> LookupResult<()> {
193 let expected =
194 BASE32_NOPAD.decode(hash.as_bytes()).map_err(|_| LookupError::HashMismatch(hash.into()))?;
195 let actual = keccak256(entry_txt.as_bytes());
196
197 if !(MIN_HASH_BYTES..=MAX_HASH_BYTES).contains(&expected.len()) {
198 return Err(LookupError::HashMismatch(hash.into()))
199 }
200
201 if actual.as_slice().starts_with(&expected) {
202 Ok(())
203 } else {
204 Err(LookupError::HashMismatch(hash.into()))
205 }
206}
207
208async fn resolve_root<K: EnrKeyUnambiguous, R: Resolver>(
213 resolver: Arc<R>,
214 link: LinkEntry<K>,
215 timeout: Duration,
216) -> ResolveRootResult<K> {
217 let root = match lookup_with_timeout::<R>(&resolver, &link.domain, timeout).await {
218 Ok(Some(root)) => root,
219 Ok(_) => return Err((LookupError::EntryNotFound, link)),
220 Err(err) => return Err((err, link)),
221 };
222
223 match root.parse::<TreeRootEntry>() {
224 Ok(root) => {
225 if root.verify::<K>(&link.pubkey) {
226 Ok((root, link))
227 } else {
228 Err((LookupError::InvalidRoot(root), link))
229 }
230 }
231 Err(err) => Err((err.into(), link)),
232 }
233}
234
235async fn lookup_with_timeout<R: Resolver>(
236 r: &R,
237 query: &str,
238 timeout: Duration,
239) -> LookupResult<Option<String>> {
240 tokio::time::timeout(timeout, r.lookup_txt(query))
241 .await
242 .map_err(|_| LookupError::RequestTimedOut)
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{resolver::TimeoutResolver, DnsDiscoveryConfig, MapResolver};
    use std::future::poll_fn;

    /// Builds a subdomain hash for `entry_txt`: the base32 encoding of the first 16 bytes of its
    /// keccak256 digest.
    fn entry_hash(entry_txt: &str) -> String {
        BASE32_NOPAD.encode(&keccak256(entry_txt.as_bytes()).as_slice()[..16])
    }

    /// After `max_requests_per_sec` admissions, further queries remain queued.
    #[tokio::test]
    async fn test_rate_limit() {
        let resolver = Arc::new(MapResolver::default());
        let config = DnsDiscoveryConfig::default();
        let mut pool = QueryPool::new(resolver, config.max_requests_per_sec, config.lookup_timeout);

        let s = "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@nodes.example.org";
        let entry: LinkEntry = s.parse().unwrap();

        // While the limiter is ready, each freshly queued query is admitted by a single poll.
        for _n in 0..config.max_requests_per_sec.get() {
            poll_fn(|cx| {
                pool.resolve_root(entry.clone());
                assert_eq!(pool.queued_queries.len(), 1);
                assert!(pool.rate_limit.poll_ready(cx).is_ready());
                let _ = pool.poll(cx);
                assert_eq!(pool.queued_queries.len(), 0);
                Poll::Ready(())
            })
            .await;
        }

        // The limiter is now pending, so the next query is not admitted.
        pool.resolve_root(entry.clone());
        assert_eq!(pool.queued_queries.len(), 1);
        poll_fn(|cx| {
            assert!(pool.rate_limit.poll_ready(cx).is_pending());
            let _ = pool.poll(cx);
            assert_eq!(pool.queued_queries.len(), 1);
            Poll::Ready(())
        })
        .await;
    }

    /// A lookup slower than `lookup_timeout` surfaces as `LookupError::RequestTimedOut`.
    #[tokio::test]
    async fn test_timeouts() {
        let config =
            DnsDiscoveryConfig { lookup_timeout: Duration::from_millis(500), ..Default::default() };
        // NOTE: TimeoutResolver is given a delay longer than the lookup timeout; presumably it
        // waits that long before answering.
        let resolver = Arc::new(TimeoutResolver(config.lookup_timeout * 2));
        let mut pool = QueryPool::new(resolver, config.max_requests_per_sec, config.lookup_timeout);

        let s = "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@nodes.example.org";
        let entry: LinkEntry = s.parse().unwrap();
        pool.resolve_root(entry);

        let outcome = poll_fn(|cx| pool.poll(cx)).await;

        let QueryOutcome::Root(res) = outcome else {
            panic!("expected the outcome of the queued root query")
        };
        let err = res.unwrap_err().0;
        assert!(matches!(err, LookupError::RequestTimedOut));
    }

    #[test]
    fn verify_entry_hash_accepts_eip_1459_vectors() {
        let entries = [
            (
                "C7HRFPF3BLGF3YR4DY5KX3SMBE",
                "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@morenodes.example.org",
            ),
            (
                "JWXYDBPXYWG6FX3GMDIBFA6CJ4",
                "enrtree-branch:2XS2367YHAXJFGLZHVAWLQD4ZY,H4FHT4B454P6UXFD7JCYQ5PWDY,MHTDO6TMUBRIA2XWG5LUDACK24",
            ),
            (
                "2XS2367YHAXJFGLZHVAWLQD4ZY",
                "enr:-HW4QOFzoVLaFJnNhbgMoDXPnOvcdVuj7pDpqRvh6BRDO68aVi5ZcjB3vzQRZH2IcLBGHzo8uUN3snqmgTiE56CH3AMBgmlkgnY0iXNlY3AyNTZrMaECC2_24YYkYHEgdzxlSNKQEnHhuNAbNlMlWJxrJxbAFvA",
            ),
        ];

        for (hash, entry) in entries {
            verify_entry_hash(hash, entry).unwrap();
        }
    }

    #[test]
    fn verify_entry_hash_rejects_mismatched_or_invalid_hashes() {
        let entry = "enrtree-branch:YNEGZIWHOM7TOOSUATAPTM";
        let hash = entry_hash(entry);
        verify_entry_hash(&hash, entry).unwrap();

        // Correct hash, different content.
        assert!(matches!(
            verify_entry_hash(&hash, "enrtree-branch:AAAAAAAAAAAAAAAAAAAA"),
            Err(LookupError::HashMismatch(_))
        ));
        // Not decodable as base32.
        assert!(matches!(
            verify_entry_hash("NOT_BASE32!", entry),
            Err(LookupError::HashMismatch(_))
        ));
        // Decodes to fewer than MIN_HASH_BYTES bytes.
        assert!(matches!(verify_entry_hash("AAAA", entry), Err(LookupError::HashMismatch(_))));
    }
}