1use crate::{
4 error::{LookupError, LookupResult},
5 resolver::Resolver,
6 sync::ResolveKind,
7 tree::{DnsEntry, LinkEntry, TreeRootEntry},
8};
9use enr::EnrKeyUnambiguous;
10use reth_tokio_util::ratelimit::{Rate, RateLimit};
11use std::{
12 collections::VecDeque,
13 future::Future,
14 num::NonZeroUsize,
15 pin::Pin,
16 sync::Arc,
17 task::{ready, Context, Poll},
18 time::Duration,
19};
20
21pub(crate) struct QueryPool<R: Resolver, K: EnrKeyUnambiguous> {
23 resolver: Arc<R>,
25 queued_queries: VecDeque<Query<K>>,
27 active_queries: Vec<Query<K>>,
29 queued_outcomes: VecDeque<QueryOutcome<K>>,
31 rate_limit: RateLimit,
33 lookup_timeout: Duration,
35}
36
37impl<R: Resolver, K: EnrKeyUnambiguous> QueryPool<R, K> {
40 pub(crate) fn new(
41 resolver: Arc<R>,
42 max_requests_per_sec: NonZeroUsize,
43 lookup_timeout: Duration,
44 ) -> Self {
45 Self {
46 resolver,
47 queued_queries: Default::default(),
48 active_queries: vec![],
49 queued_outcomes: Default::default(),
50 rate_limit: RateLimit::new(Rate::new(
51 max_requests_per_sec.get() as u64,
52 Duration::from_secs(1),
53 )),
54 lookup_timeout,
55 }
56 }
57
58 pub(crate) fn resolve_root(&mut self, link: LinkEntry<K>) {
60 let resolver = Arc::clone(&self.resolver);
61 let timeout = self.lookup_timeout;
62 self.queued_queries.push_back(Query::Root(Box::pin(resolve_root(resolver, link, timeout))))
63 }
64
65 pub(crate) fn resolve_entry(&mut self, link: LinkEntry<K>, hash: String, kind: ResolveKind) {
67 let resolver = Arc::clone(&self.resolver);
68 let timeout = self.lookup_timeout;
69 self.queued_queries
70 .push_back(Query::Entry(Box::pin(resolve_entry(resolver, link, hash, kind, timeout))))
71 }
72
73 pub(crate) fn poll(&mut self, cx: &mut Context<'_>) -> Poll<QueryOutcome<K>> {
75 loop {
76 if let Some(event) = self.queued_outcomes.pop_front() {
78 return Poll::Ready(event)
79 }
80
81 'queries: while self.active_queries.len() < self.rate_limit.limit() as usize {
83 if self.rate_limit.poll_ready(cx).is_ready() {
84 if let Some(query) = self.queued_queries.pop_front() {
85 self.rate_limit.tick();
86 self.active_queries.push(query);
87 continue 'queries
88 }
89 }
90 break
91 }
92
93 for idx in (0..self.active_queries.len()).rev() {
95 let mut query = self.active_queries.swap_remove(idx);
96 if let Poll::Ready(outcome) = query.poll(cx) {
97 self.queued_outcomes.push_back(outcome);
98 } else {
99 self.active_queries.push(query);
101 }
102 }
103
104 if self.queued_outcomes.is_empty() {
105 return Poll::Pending
106 }
107 }
108 }
109}
110
111pub(crate) struct ResolveEntryResult<K: EnrKeyUnambiguous> {
114 pub(crate) entry: Option<LookupResult<DnsEntry<K>>>,
115 pub(crate) link: LinkEntry<K>,
116 pub(crate) hash: String,
117 pub(crate) kind: ResolveKind,
118}
119
120pub(crate) type ResolveRootResult<K> =
121 Result<(TreeRootEntry, LinkEntry<K>), (LookupError, LinkEntry<K>)>;
122
123type ResolveRootFuture<K> = Pin<Box<dyn Future<Output = ResolveRootResult<K>> + Send>>;
124
125type ResolveEntryFuture<K> = Pin<Box<dyn Future<Output = ResolveEntryResult<K>> + Send>>;
126
127enum Query<K: EnrKeyUnambiguous> {
128 Root(ResolveRootFuture<K>),
129 Entry(ResolveEntryFuture<K>),
130}
131
132impl<K: EnrKeyUnambiguous> Query<K> {
135 fn poll(&mut self, cx: &mut Context<'_>) -> Poll<QueryOutcome<K>> {
137 match self {
138 Self::Root(query) => {
139 let outcome = ready!(query.as_mut().poll(cx));
140 Poll::Ready(QueryOutcome::Root(outcome))
141 }
142 Self::Entry(query) => {
143 let outcome = ready!(query.as_mut().poll(cx));
144 Poll::Ready(QueryOutcome::Entry(outcome))
145 }
146 }
147 }
148}
149
150pub(crate) enum QueryOutcome<K: EnrKeyUnambiguous> {
152 Root(ResolveRootResult<K>),
153 Entry(ResolveEntryResult<K>),
154}
155
156async fn resolve_entry<K: EnrKeyUnambiguous, R: Resolver>(
158 resolver: Arc<R>,
159 link: LinkEntry<K>,
160 hash: String,
161 kind: ResolveKind,
162 timeout: Duration,
163) -> ResolveEntryResult<K> {
164 let fqn = format!("{hash}.{}", link.domain);
165 let mut resp = ResolveEntryResult { entry: None, link, hash, kind };
166 match lookup_with_timeout::<R>(&resolver, &fqn, timeout).await {
167 Ok(Some(entry)) => {
168 resp.entry = Some(entry.parse::<DnsEntry<K>>().map_err(|err| err.into()))
169 }
170 Err(err) => resp.entry = Some(Err(err)),
171 Ok(None) => {}
172 }
173 resp
174}
175
176async fn resolve_root<K: EnrKeyUnambiguous, R: Resolver>(
181 resolver: Arc<R>,
182 link: LinkEntry<K>,
183 timeout: Duration,
184) -> ResolveRootResult<K> {
185 let root = match lookup_with_timeout::<R>(&resolver, &link.domain, timeout).await {
186 Ok(Some(root)) => root,
187 Ok(_) => return Err((LookupError::EntryNotFound, link)),
188 Err(err) => return Err((err, link)),
189 };
190
191 match root.parse::<TreeRootEntry>() {
192 Ok(root) => {
193 if root.verify::<K>(&link.pubkey) {
194 Ok((root, link))
195 } else {
196 Err((LookupError::InvalidRoot(root), link))
197 }
198 }
199 Err(err) => Err((err.into(), link)),
200 }
201}
202
203async fn lookup_with_timeout<R: Resolver>(
204 r: &R,
205 query: &str,
206 timeout: Duration,
207) -> LookupResult<Option<String>> {
208 tokio::time::timeout(timeout, r.lookup_txt(query))
209 .await
210 .map_err(|_| LookupError::RequestTimedOut)
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{resolver::TimeoutResolver, DnsDiscoveryConfig, MapResolver};
    use std::future::poll_fn;

    /// Tree link shared by the tests in this module.
    const LINK: &str =
        "enrtree://AM5FCQLWIZX2QFPNJAP7VUERCCRNGRHWZG3YYHIUV7BVDQ5FDPRT2@nodes.example.org";

    /// Parses [`LINK`] into a [`LinkEntry`] with the default key type.
    fn link_entry() -> LinkEntry {
        LINK.parse().unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit() {
        let config = DnsDiscoveryConfig::default();
        let mut pool = QueryPool::new(
            Arc::new(MapResolver::default()),
            config.max_requests_per_sec,
            config.lookup_timeout,
        );
        let link = link_entry();

        // Up to `max_requests_per_sec` queries are admitted straight away.
        for _ in 0..config.max_requests_per_sec.get() {
            poll_fn(|cx| {
                pool.resolve_root(link.clone());
                assert_eq!(pool.queued_queries.len(), 1);
                assert!(pool.rate_limit.poll_ready(cx).is_ready());
                let _ = pool.poll(cx);
                assert_eq!(pool.queued_queries.len(), 0);
                Poll::Ready(())
            })
            .await;
        }

        // One more query exceeds the per-second budget and must remain queued.
        pool.resolve_root(link.clone());
        assert_eq!(pool.queued_queries.len(), 1);
        poll_fn(|cx| {
            assert!(pool.rate_limit.poll_ready(cx).is_pending());
            let _ = pool.poll(cx);
            assert_eq!(pool.queued_queries.len(), 1);
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn test_timeouts() {
        let config =
            DnsDiscoveryConfig { lookup_timeout: Duration::from_millis(500), ..Default::default() };
        // The resolver is configured with twice the lookup timeout, so the lookup is
        // expected to hit the deadline.
        let resolver = Arc::new(TimeoutResolver(config.lookup_timeout * 2));
        let mut pool = QueryPool::new(resolver, config.max_requests_per_sec, config.lookup_timeout);
        pool.resolve_root(link_entry());

        let outcome = poll_fn(|cx| pool.poll(cx)).await;
        assert!(
            matches!(outcome, QueryOutcome::Root(Err((LookupError::RequestTimedOut, _)))),
            "expected the root lookup to time out"
        );
    }
}