reth_dns_discovery/
query.rs

//! Query execution: a pool that drives rate-limited DNS TXT lookups, each bounded by a timeout.
2
3use crate::{
4    error::{LookupError, LookupResult},
5    resolver::Resolver,
6    sync::ResolveKind,
7    tree::{DnsEntry, LinkEntry, TreeRootEntry},
8};
9use enr::EnrKeyUnambiguous;
10use reth_tokio_util::ratelimit::{Rate, RateLimit};
11use std::{
12    collections::VecDeque,
13    future::Future,
14    num::NonZeroUsize,
15    pin::Pin,
16    sync::Arc,
17    task::{ready, Context, Poll},
18    time::Duration,
19};
20
21/// The `QueryPool` provides an aggregate state machine for driving queries to completion.
22pub(crate) struct QueryPool<R: Resolver, K: EnrKeyUnambiguous> {
23    /// The [Resolver] that's used to lookup queries.
24    resolver: Arc<R>,
25    /// Buffered queries
26    queued_queries: VecDeque<Query<K>>,
27    /// All active queries
28    active_queries: Vec<Query<K>>,
29    /// buffered results
30    queued_outcomes: VecDeque<QueryOutcome<K>>,
31    /// Rate limit for DNS requests
32    rate_limit: RateLimit,
33    /// Timeout for DNS lookups.
34    lookup_timeout: Duration,
35}
36
37// === impl QueryPool ===
38
39impl<R: Resolver, K: EnrKeyUnambiguous> QueryPool<R, K> {
40    pub(crate) fn new(
41        resolver: Arc<R>,
42        max_requests_per_sec: NonZeroUsize,
43        lookup_timeout: Duration,
44    ) -> Self {
45        Self {
46            resolver,
47            queued_queries: Default::default(),
48            active_queries: vec![],
49            queued_outcomes: Default::default(),
50            rate_limit: RateLimit::new(Rate::new(
51                max_requests_per_sec.get() as u64,
52                Duration::from_secs(1),
53            )),
54            lookup_timeout,
55        }
56    }
57
58    /// Resolves the root the link's domain references
59    pub(crate) fn resolve_root(&mut self, link: LinkEntry<K>) {
60        let resolver = Arc::clone(&self.resolver);
61        let timeout = self.lookup_timeout;
62        self.queued_queries.push_back(Query::Root(Box::pin(resolve_root(resolver, link, timeout))))
63    }
64
65    /// Resolves the [`DnsEntry`] for `<hash.domain>`
66    pub(crate) fn resolve_entry(&mut self, link: LinkEntry<K>, hash: String, kind: ResolveKind) {
67        let resolver = Arc::clone(&self.resolver);
68        let timeout = self.lookup_timeout;
69        self.queued_queries
70            .push_back(Query::Entry(Box::pin(resolve_entry(resolver, link, hash, kind, timeout))))
71    }
72
73    /// Advances the state of the queries
74    pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<QueryOutcome<K>> {
75        loop {
76            // drain buffered events first
77            if let Some(event) = self.queued_outcomes.pop_front() {
78                return Poll::Ready(event)
79            }
80
81            // queue in new queries if we have capacity
82            'queries: while self.active_queries.len() < self.rate_limit.limit() as usize {
83                if self.rate_limit.poll_ready(cx).is_ready() {
84                    if let Some(query) = self.queued_queries.pop_front() {
85                        self.rate_limit.tick();
86                        self.active_queries.push(query);
87                        continue 'queries
88                    }
89                }
90                break
91            }
92
93            // advance all queries
94            for idx in (0..self.active_queries.len()).rev() {
95                let mut query = self.active_queries.swap_remove(idx);
96                if let Poll::Ready(outcome) = query.poll(cx) {
97                    self.queued_outcomes.push_back(outcome);
98                } else {
99                    // still pending
100                    self.active_queries.push(query);
101                }
102            }
103
104            if self.queued_outcomes.is_empty() {
105                return Poll::Pending
106            }
107        }
108    }
109}
110
111// === Various future/type alias ===
112
113pub(crate) struct ResolveEntryResult<K: EnrKeyUnambiguous> {
114    pub(crate) entry: Option<LookupResult<DnsEntry<K>>>,
115    pub(crate) link: LinkEntry<K>,
116    pub(crate) hash: String,
117    pub(crate) kind: ResolveKind,
118}
119
/// Result of resolving a link's root.
///
/// On success it holds the verified [`TreeRootEntry`]; on failure, the [`LookupError`]. Both
/// variants give the [`LinkEntry`] back to the caller.
pub(crate) type ResolveRootResult<K> =
    Result<(TreeRootEntry, LinkEntry<K>), (LookupError, LinkEntry<K>)>;

/// Type-erased, `Send` future that holds a pending `resolve_root` call.
type ResolveRootFuture<K> = Pin<Box<dyn Future<Output = ResolveRootResult<K>> + Send>>;

/// Type-erased, `Send` future that holds a pending `resolve_entry` call.
type ResolveEntryFuture<K> = Pin<Box<dyn Future<Output = ResolveEntryResult<K>> + Send>>;
126
127enum Query<K: EnrKeyUnambiguous> {
128    Root(ResolveRootFuture<K>),
129    Entry(ResolveEntryFuture<K>),
130}
131
132// === impl Query ===
133
134impl<K: EnrKeyUnambiguous> Query<K> {
135    /// Advances the query
136    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<QueryOutcome<K>> {
137        match self {
138            Self::Root(query) => {
139                let outcome = ready!(query.as_mut().poll(cx));
140                Poll::Ready(QueryOutcome::Root(outcome))
141            }
142            Self::Entry(query) => {
143                let outcome = ready!(query.as_mut().poll(cx));
144                Poll::Ready(QueryOutcome::Entry(outcome))
145            }
146        }
147    }
148}
149
150/// The output the queries return
151pub(crate) enum QueryOutcome<K: EnrKeyUnambiguous> {
152    Root(ResolveRootResult<K>),
153    Entry(ResolveEntryResult<K>),
154}
155
156/// Retrieves the [`DnsEntry`]
157async fn resolve_entry<K: EnrKeyUnambiguous, R: Resolver>(
158    resolver: Arc<R>,
159    link: LinkEntry<K>,
160    hash: String,
161    kind: ResolveKind,
162    timeout: Duration,
163) -> ResolveEntryResult<K> {
164    let fqn = format!("{hash}.{}", link.domain);
165    let mut resp = ResolveEntryResult { entry: None, link, hash, kind };
166    match lookup_with_timeout::<R>(&resolver, &fqn, timeout).await {
167        Ok(Some(entry)) => {
168            resp.entry = Some(entry.parse::<DnsEntry<K>>().map_err(|err| err.into()))
169        }
170        Err(err) => resp.entry = Some(Err(err)),
171        Ok(None) => {}
172    }
173    resp
174}
175
176/// Retrieves the root entry the link points to and returns the verified entry
177///
178/// Returns an error if the record could be retrieved but is not a root entry or failed to be
179/// verified.
180async fn resolve_root<K: EnrKeyUnambiguous, R: Resolver>(
181    resolver: Arc<R>,
182    link: LinkEntry<K>,
183    timeout: Duration,
184) -> ResolveRootResult<K> {
185    let root = match lookup_with_timeout::<R>(&resolver, &link.domain, timeout).await {
186        Ok(Some(root)) => root,
187        Ok(_) => return Err((LookupError::EntryNotFound, link)),
188        Err(err) => return Err((err, link)),
189    };
190
191    match root.parse::<TreeRootEntry>() {
192        Ok(root) => {
193            if root.verify::<K>(&link.pubkey) {
194                Ok((root, link))
195            } else {
196                Err((LookupError::InvalidRoot(root), link))
197            }
198        }
199        Err(err) => Err((err.into(), link)),
200    }
201}
202
203async fn lookup_with_timeout<R: Resolver>(
204    r: &R,
205    query: &str,
206    timeout: Duration,
207) -> LookupResult<Option<String>> {
208    tokio::time::timeout(timeout, r.lookup_txt(query))
209        .await
210        .map_err(|_| LookupError::RequestTimedOut)
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{resolver::TimeoutResolver, DnsDiscoveryConfig, MapResolver};
    use std::future::poll_fn;

    /// Queues `max_requests_per_sec` root lookups one by one. Each must be promoted out of the
    /// queue immediately. One more lookup must then stay queued because the rate limit is
    /// exhausted.
    #[tokio::test]
    async fn test_rate_limit() {
        let resolver = Arc::new(MapResolver::default());
        let config = DnsDiscoveryConfig::default();
        let mut pool = QueryPool::new(resolver, config.max_requests_per_sec, config.lookup_timeout);

        let s = "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@nodes.example.org";
        let entry: LinkEntry = s.parse().unwrap();

        for _n in 0..config.max_requests_per_sec.get() {
            poll_fn(|cx| {
                pool.resolve_root(entry.clone());
                assert_eq!(pool.queued_queries.len(), 1);
                assert!(pool.rate_limit.poll_ready(cx).is_ready());
                let _ = pool.poll(cx);
                assert_eq!(pool.queued_queries.len(), 0);
                Poll::Ready(())
            })
            .await;
        }

        // Budget for this window is used up: the next query must remain queued.
        pool.resolve_root(entry.clone());
        assert_eq!(pool.queued_queries.len(), 1);
        poll_fn(|cx| {
            assert!(pool.rate_limit.poll_ready(cx).is_pending());
            let _ = pool.poll(cx);
            assert_eq!(pool.queued_queries.len(), 1);
            Poll::Ready(())
        })
        .await;
    }

    /// A resolver slower than `lookup_timeout` must make the root query fail with
    /// `LookupError::RequestTimedOut`.
    #[tokio::test]
    async fn test_timeouts() {
        let config =
            DnsDiscoveryConfig { lookup_timeout: Duration::from_millis(500), ..Default::default() };
        // Presumably `TimeoutResolver` delays its answer by the given duration, i.e. twice the
        // configured lookup timeout here.
        let resolver = Arc::new(TimeoutResolver(config.lookup_timeout * 2));
        let mut pool = QueryPool::new(resolver, config.max_requests_per_sec, config.lookup_timeout);

        let s = "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@nodes.example.org";
        let entry: LinkEntry = s.parse().unwrap();
        pool.resolve_root(entry);

        let outcome = poll_fn(|cx| pool.poll(cx)).await;

        match outcome {
            QueryOutcome::Root(res) => {
                let err = res.unwrap_err().0;
                assert!(
                    matches!(err, LookupError::RequestTimedOut),
                    "root lookup failed with an unexpected error instead of timing out"
                );
            }
            QueryOutcome::Entry(_) => panic!("expected a root outcome, got an entry outcome"),
        }
    }
}